Unverified Commit cbced7f0 authored by drbh, committed by GitHub

feat: adjust attn weight loading logic (#1975)

This PR updates `load_attention` to prefer loading model-specific attention
weights based on the model type. Additionally, `TensorParallelColumnLinear.load_multi`
was previously called in two places; this change reduces them to a single path.
parent 612bc483
@@ -49,30 +49,24 @@ if SYSTEM == "rocm":
 def load_attention(config, prefix, weights):
     bias = config.attention_bias
-    if config.num_attention_heads != config.num_key_value_heads:
-        return TensorParallelColumnLinear.load_multi(
-            config,
-            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-            dim=0,
-            weights=weights,
-            bias=bias,
-        )
-    else:
-        if config.model_type == "baichuan":
-            return TensorParallelColumnLinear.load_qkv(
-                config,
-                prefix=f"{prefix}.W_pack",
-                weights=weights,
-                bias=bias,
-            )
-        elif config.model_type == "phi3":
-            return TensorParallelColumnLinear.load_qkv(
-                config,
-                prefix=f"{prefix}.qkv_proj",
-                weights=weights,
-                bias=bias,
-            )
-        else:
-            return TensorParallelColumnLinear.load_multi(
-                config,
-                prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+
+    # if specific model type, load the correct attention
+    if config.model_type == "phi3":
+        return TensorParallelColumnLinear.load_qkv(
+            config,
+            prefix=f"{prefix}.qkv_proj",
+            weights=weights,
+            bias=bias,
+        )
+    elif config.model_type == "baichuan":
+        return TensorParallelColumnLinear.load_qkv(
+            config,
+            prefix=f"{prefix}.W_pack",
+            weights=weights,
+            bias=bias,
+        )
+
+    # otherwise, load the default attention based on the number of heads
+    return TensorParallelColumnLinear.load_multi(
+        config,
+        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
...
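
For readability, below is a sketch of the resulting `load_attention` reassembled from the hunk above, assuming the module's existing `TensorParallelColumnLinear` import. The tail of the final `load_multi` call is cut off in the diff, so its trailing arguments (`dim=0`, `weights=weights`, `bias=bias`) are assumed here from the removed branch rather than taken from the new code.

def load_attention(config, prefix, weights):
    bias = config.attention_bias

    # fused-QKV checkpoints: dispatch on model type first
    if config.model_type == "phi3":
        # phi3 stores query/key/value fused under a single qkv_proj tensor
        return TensorParallelColumnLinear.load_qkv(
            config,
            prefix=f"{prefix}.qkv_proj",
            weights=weights,
            bias=bias,
        )
    elif config.model_type == "baichuan":
        # baichuan fuses query/key/value under W_pack
        return TensorParallelColumnLinear.load_qkv(
            config,
            prefix=f"{prefix}.W_pack",
            weights=weights,
            bias=bias,
        )

    # default: load separate q/k/v projections; this single call replaces the
    # two load_multi call sites in the previous version
    return TensorParallelColumnLinear.load_multi(
        config,
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        dim=0,  # assumed from the removed branch; the diff hunk is truncated here
        weights=weights,
        bias=bias,
    )

Note that the old head-count check (`num_attention_heads != num_key_value_heads`) is no longer needed: both the multi-head and grouped-query cases fall through to the same `load_multi` call, which is why the two previous call sites collapse into one.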