Unverified commit d3a24513, authored by Baoyuan Qi, committed by GitHub
Browse files

[Bugfix] Fix needs_scalar_to_array logic check (#6238)


Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
parent 673dd4ca
@@ -387,7 +387,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         if loaded_shard_id is None:
             # Loaded weight is already fused on disk (qkv/mlp).
             if output_dim is None:
-                if needs_scalar_to_array is not None:
+                if needs_scalar_to_array:
                     param_data, loaded_weight = adjust_scalar_to_fused_array(
                         param_data, loaded_weight, 0)
@@ -549,7 +549,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         if loaded_shard_id is None:
             # Loaded weight is already fused on disk (qkv/mlp).
             if output_dim is None:
-                if needs_scalar_to_array is not None:
+                if needs_scalar_to_array:
                     param_data, loaded_weight = adjust_scalar_to_fused_array(
                         param_data, loaded_weight, 0)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment