Unverified Commit 00a4e56d authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Bugfix] Fix broken deepseek fp8 TP weights loading (#24367)


Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent 0eadaeff
......@@ -262,7 +262,7 @@ class LinearBase(CustomOp):
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
def __post_init__(self):
def update_param_tp_status(self):
for param in self.parameters():
if isinstance(param, BasevLLMParameter):
param.tp_rank = self.tp_rank
......@@ -459,6 +459,7 @@ class ColumnParallelLinear(LinearBase):
})
else:
self.register_parameter("bias", None)
self.update_param_tp_status()
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
......@@ -1250,6 +1251,7 @@ class RowParallelLinear(LinearBase):
})
else:
self.register_parameter("bias", None)
self.update_param_tp_status()
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
input_dim = getattr(param, "input_dim", None)
......
......@@ -270,7 +270,8 @@ class Fp8LinearMethod(LinearMethodBase):
layer.weight_block_size = None
if self.block_quant:
tp_size = get_tensor_model_parallel_world_size()
tp_size = getattr(layer, "tp_size",
get_tensor_model_parallel_world_size())
assert self.quant_config.weight_block_size is not None
layer.weight_block_size = self.quant_config.weight_block_size
block_n, block_k = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment