Unverified Commit c0f01191 authored by Krish Gupta's avatar Krish Gupta Committed by GitHub
Browse files

[Bugfix] opcheck false mutation error in rms_norm_per_block_quant (#36688) (#36779)


Signed-off-by: default avatarKrish Gupta <krishom70@gmail.com>
parent e6ae4b1b
...@@ -286,6 +286,15 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, ...@@ -286,6 +286,15 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
"Outer scale stride must be 1 when scales are not transposed"); "Outer scale stride must be 1 when scales are not transposed");
} }
int64_t hidden_size = input.size(-1);
TORCH_CHECK(hidden_size > 0 && hidden_size % group_size == 0,
"hidden_size must be a positive multiple of group_size");
int64_t num_tokens = input.numel() / hidden_size;
int64_t num_groups = hidden_size / group_size;
TORCH_CHECK(scales.numel() >= num_tokens * num_groups,
"scales buffer too small: need ", num_tokens * num_groups,
" elements, got ", scales.numel());
rms_norm_per_block_quant_dispatch(out, input, weight, scales, group_size, rms_norm_per_block_quant_dispatch(out, input, weight, scales, group_size,
var_epsilon, scale_ub, residual, var_epsilon, scale_ub, residual,
is_scale_transposed); is_scale_transposed);
......
...@@ -280,21 +280,22 @@ def test_rms_norm( ...@@ -280,21 +280,22 @@ def test_rms_norm(
assert torch.allclose(ref_residual, ops_residual) assert torch.allclose(ref_residual, ops_residual)
output = torch.empty(x.shape, dtype=quant_dtype, device=x.device) output = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
scales = torch.empty(
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
)
if group_size is None: if group_size is None:
scales = torch.empty(
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
)
opcheck( opcheck(
torch.ops._C.rms_norm_dynamic_per_token_quant, torch.ops._C.rms_norm_dynamic_per_token_quant,
(output, x, layer.weight, scales, 1e-5, scale_ub, residual), (output, x, layer.weight, scales, 1e-5, scale_ub, residual),
) )
else: else:
# TODO(luka/eliza) opcheck is broken? assert hidden_size % group_size[1] == 0
# Somehow the cloned args are getting mutated in-place, num_groups = hidden_size // group_size[1]
# which causes the opcheck to fail. scales = torch.empty(
# https://github.com/vllm-project/vllm/issues/36688 (num_groups, num_tokens),
return device=x.device,
dtype=torch.float32,
).transpose(0, 1)
opcheck( opcheck(
torch.ops._C.rms_norm_per_block_quant, torch.ops._C.rms_norm_per_block_quant,
( (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment