Unverified commit 325f679f, authored by Robert Shaw, committed by GitHub
Browse files

[BugFix] Fix Torch.Compile For DeepSeek (#12594)


Co-authored-by: simon-mo <xmo@berkeley.edu>
parent e3f7ff65
...@@ -245,20 +245,24 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -245,20 +245,24 @@ class Fp8LinearMethod(LinearMethodBase):
layer.register_parameter("input_scale", None) layer.register_parameter("input_scale", None)
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading # TODO(rob): refactor block quant into separate class.
if self.block_quant: if self.block_quant:
assert self.quant_config.activation_scheme == "dynamic"
if current_platform.is_rocm(): if current_platform.is_rocm():
weight, weight_scale, _ = \ weight, weight_scale_inv, _ = \
normalize_e4m3fn_to_e4m3fnuz( normalize_e4m3fn_to_e4m3fnuz(
weight=layer.weight, weight=layer.weight,
weight_scale=layer.weight_scale_inv, weight_scale=layer.weight_scale_inv)
input_scale=layer.input_scale) else:
weight = layer.weight.data
weight_scale_inv = layer.weight_scale_inv.data
# Torch.compile cannot use Parameter subclasses.
layer.weight = Parameter(weight, requires_grad=False) layer.weight = Parameter(weight, requires_grad=False)
layer.weight_scale_inv = Parameter(weight_scale, layer.weight_scale_inv = Parameter(weight_scale_inv,
requires_grad=False) requires_grad=False)
return return
layer.weight = torch.nn.Parameter(layer.weight.data,
requires_grad=False)
# If checkpoint not serialized fp8, quantize the weights. # If checkpoint not serialized fp8, quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized: if not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
...@@ -507,8 +511,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -507,8 +511,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_input_scale = None layer.w2_input_scale = None
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading # TODO (rob): refactor block quant into separate class.
if self.block_quant: if self.block_quant:
assert self.quant_config.activation_scheme == "dynamic"
if current_platform.is_rocm(): if current_platform.is_rocm():
w13_weight, w13_weight_scale_inv, w13_input_scale = \ w13_weight, w13_weight_scale_inv, w13_input_scale = \
normalize_e4m3fn_to_e4m3fnuz( normalize_e4m3fn_to_e4m3fnuz(
...@@ -518,22 +523,21 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -518,22 +523,21 @@ class Fp8MoEMethod(FusedMoEMethodBase):
normalize_e4m3fn_to_e4m3fnuz( normalize_e4m3fn_to_e4m3fnuz(
layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_weight, layer.w2_weight_scale_inv,
layer.w2_input_scale) layer.w2_input_scale)
# Reset the parameter else:
layer.w13_weight = torch.nn.Parameter(w13_weight, w13_weight = layer.w13_weight.data
w13_weight_scale_inv = layer.w13_weight_scale_inv.data
w2_weight = layer.w2_weight
w2_weight_scale_inv = layer.w2_weight_scale_inv
# torch.compile() cannot use Parameter subclasses.
layer.w13_weight = Parameter(w13_weight, requires_grad=False)
layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
requires_grad=False) requires_grad=False)
layer.w13_weight_scale_inv = torch.nn.Parameter( layer.w2_weight = Parameter(w2_weight, requires_grad=False)
w13_weight_scale_inv, requires_grad=False) layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
if w13_input_scale is not None:
layer.w13_input_scale = torch.nn.Parameter(
w13_input_scale, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight,
requires_grad=False) requires_grad=False)
layer.w2_weight_scale_inv = torch.nn.Parameter(
w2_weight_scale_inv, requires_grad=False)
if w2_input_scale is not None:
layer.w2_input_scale = torch.nn.Parameter(
w2_input_scale, requires_grad=False)
return return
# If checkpoint is fp16, quantize in place. # If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized: if not self.quant_config.is_checkpoint_fp8_serialized:
# If rocm, use float8_e4m3fnuz as dtype # If rocm, use float8_e4m3fnuz as dtype
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment