Unverified Commit d6da8a8f authored by Richard Barnes's avatar Richard Barnes Committed by GitHub
Browse files

[Bugfix] Fix `numel()` downcast in fused_layernorm_dynamic_per_token_quant.cu (#17316)

parent b4ac4fa0
......@@ -96,7 +96,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
std::optional<at::Tensor> const& scale_ub,
std::optional<at::Tensor>& residual) {
int32_t hidden_size = input.size(-1);
int32_t num_tokens = input.numel() / hidden_size;
auto num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment