Unverified commit b1852307, authored by Robert Shaw and committed by GitHub
Browse files

[ Misc ] Remove `fp8_shard_indexer` from Col/Row Parallel Linear (Simplify Weight Loading) (#5928)


Co-authored-by: Robert Shaw <rshaw@neuralmagic>
parent 6a2d659d
...@@ -269,10 +269,6 @@ class ColumnParallelLinear(LinearBase): ...@@ -269,10 +269,6 @@ class ColumnParallelLinear(LinearBase):
self.register_parameter("bias", None) self.register_parameter("bias", None)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
param_data = param.data param_data = param.data
...@@ -281,11 +277,11 @@ class ColumnParallelLinear(LinearBase): ...@@ -281,11 +277,11 @@ class ColumnParallelLinear(LinearBase):
start_idx = tp_rank * shard_size start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx, loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size) shard_size)
# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None: # Special case for loading scales off disk, which often do not
param_data, loaded_weight = fp8_scales_shard_indexer(param_data, # have a shape (such as in the case of AutoFP8).
loaded_weight, if len(loaded_weight.shape) == 0:
shard_id=0) loaded_weight = loaded_weight.reshape(1)
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
...@@ -751,10 +747,6 @@ class RowParallelLinear(LinearBase): ...@@ -751,10 +747,6 @@ class RowParallelLinear(LinearBase):
self.register_parameter("bias", None) self.register_parameter("bias", None)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
input_dim = getattr(param, "input_dim", None) input_dim = getattr(param, "input_dim", None)
param_data = param.data param_data = param.data
...@@ -764,13 +756,9 @@ class RowParallelLinear(LinearBase): ...@@ -764,13 +756,9 @@ class RowParallelLinear(LinearBase):
loaded_weight = loaded_weight.narrow(input_dim, start_idx, loaded_weight = loaded_weight.narrow(input_dim, start_idx,
shard_size) shard_size)
# Special case for Fp8 scales. # Special case for loading scales off disk, which often do not
elif fp8_scales_shard_indexer is not None: # have a shape (such as in the case of AutoFP8).
param_data, loaded_weight = fp8_scales_shard_indexer(param_data, if len(loaded_weight.shape) == 0:
loaded_weight,
shard_id=0)
if fp8_scales_shard_indexer is None and len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1) loaded_weight = loaded_weight.reshape(1)
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment