Unverified Commit f0e15dc6 authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

[HotFix] fix fp8 scale load failed in tp>1 (#2837)

parent f1769586
...@@ -437,7 +437,7 @@ class ColumnParallelLinear(LinearBase): ...@@ -437,7 +437,7 @@ class ColumnParallelLinear(LinearBase):
if len(loaded_weight.shape) == 0: if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1 assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1) loaded_weight = loaded_weight.reshape(1)
load_column_parallel_weight(param, loaded_weight, self.tp_rank) param.load_column_parallel_weight(loaded_weight=loaded_weight)
def forward(self, input_): def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
...@@ -1247,12 +1247,7 @@ class RowParallelLinear(LinearBase): ...@@ -1247,12 +1247,7 @@ class RowParallelLinear(LinearBase):
assert loaded_weight.numel() == 1 assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1) loaded_weight = loaded_weight.reshape(1)
load_row_parallel_weight( param.load_row_parallel_weight(loaded_weight=loaded_weight)
param,
loaded_weight,
self.tp_rank,
use_presharded_weights=self.use_presharded_weights,
)
def forward(self, input_): def forward(self, input_):
if self.input_is_parallel: if self.input_is_parallel:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment