Unverified Commit f946659f authored by EdalatiAli's avatar EdalatiAli Committed by GitHub
Browse files

[Bugfix] Fix W4A8_FP8 MoE tp>1 correctness and view() TypeError (#40310)


Signed-off-by: default avatarEdalatiAli <aliedalati@cohere.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: default avatarmergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
parent f90aa446
...@@ -198,11 +198,15 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -198,11 +198,15 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# encode and reorder weight tensors, and get the layout to pass to # encode and reorder weight tensors, and get the layout to pass to
# the grouped gemm kernel. `b_strides1/2` specifies the entire layout # the grouped gemm kernel. `b_strides1/2` specifies the entire layout
convert_packed_uint4b8_to_signed_int4_inplace(layer.w13_weight_packed) convert_packed_uint4b8_to_signed_int4_inplace(layer.w13_weight_packed)
# mirror the sync in CutlassW4A8LinearKernel; required for tp>1 correctness
torch.accelerator.synchronize()
w13_weight_shuffled, self.b_strides1 = ( w13_weight_shuffled, self.b_strides1 = (
ops.cutlass_encode_and_reorder_int4b_grouped(layer.w13_weight_packed) ops.cutlass_encode_and_reorder_int4b_grouped(layer.w13_weight_packed)
) )
replace_parameter(layer, "w13_weight_packed", w13_weight_shuffled) replace_parameter(layer, "w13_weight_packed", w13_weight_shuffled)
convert_packed_uint4b8_to_signed_int4_inplace(layer.w2_weight_packed) convert_packed_uint4b8_to_signed_int4_inplace(layer.w2_weight_packed)
# mirror the sync in CutlassW4A8LinearKernel; required for tp>1 correctness
torch.accelerator.synchronize()
w2_weight_shuffled, self.b_strides2 = ( w2_weight_shuffled, self.b_strides2 = (
ops.cutlass_encode_and_reorder_int4b_grouped(layer.w2_weight_packed) ops.cutlass_encode_and_reorder_int4b_grouped(layer.w2_weight_packed)
) )
......
...@@ -818,7 +818,7 @@ def convert_bf16_scales_to_fp8( ...@@ -818,7 +818,7 @@ def convert_bf16_scales_to_fp8(
# restore original shape # restore original shape
fp8_scales = fp8_scales.view(orig_shape) fp8_scales = fp8_scales.view(orig_shape)
chan_scales = chan_scales.view(orig_shape[:-1], -1) chan_scales = chan_scales.view(*orig_shape[:-1], -1)
return fp8_scales, chan_scales return fp8_scales, chan_scales
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment