Commit 12d93ad7 authored by huangwb

Merge branch 'dev-rocm-support-baichuan' into 'dev-rocm'

fix baichuan config init bug

See merge request huangwb1/text-generation-inference!3
parents 52647592 d461d955
@@ -39,7 +39,7 @@ from text_generation_server.utils.layers import (
 def load_attention(config, prefix, weights):
-    if config.num_attention_heads != config.num_key_value_heads:
+    if hasattr(config, 'num_key_value_heads') and config.num_attention_heads != config.num_key_value_heads:
         return _load_gqa(config, prefix, weights)
     else:
         if config.model_type == "baichuan":
@@ -107,7 +107,7 @@ class FlashLlamaAttention(torch.nn.Module):
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
             dim=self.head_size,
-            base=config.rope_theta,
+            base=config.rope_theta if hasattr(config, 'rope_theta') else 10000,
             device=weights.device,
         )
@@ -121,7 +121,7 @@ class FlashLlamaAttention(torch.nn.Module):
         self.num_heads = self.num_heads // weights.process_group.size()
         self.num_key_value_heads = (
             config.num_key_value_heads // weights.process_group.size()
-        )
+        ) if hasattr(config, 'num_key_value_heads') else (config.num_attention_heads // weights.process_group.size())
         self.query_key_value = load_attention(config, prefix, weights)
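
All three hunks apply the same guard: Baichuan checkpoints reuse the LLaMA modeling path, but their config may not define num_key_value_heads or rope_theta, so each access needs a fallback (10000 for rope_theta, num_attention_heads for num_key_value_heads, i.e. plain multi-head attention instead of GQA). The sketch below shows that fallback pattern in isolation using getattr defaults; resolve_attention_config and the SimpleNamespace config are illustrative assumptions, not code from this merge request.

# Minimal sketch of the config-fallback pattern used in this merge request.
# The SimpleNamespace config below stands in for a Baichuan-style config
# that lacks num_key_value_heads and rope_theta; it is not repository code.
from types import SimpleNamespace

def resolve_attention_config(config, world_size):
    # rope_theta: fall back to the LLaMA default of 10000 when the config
    # does not define it (same effect as the hasattr check in the diff).
    rope_theta = getattr(config, "rope_theta", 10000)

    # num_key_value_heads: fall back to num_attention_heads, i.e. treat the
    # model as standard multi-head attention rather than GQA.
    num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)

    return {
        "rope_theta": rope_theta,
        "num_heads_per_rank": config.num_attention_heads // world_size,
        "num_kv_heads_per_rank": num_kv_heads // world_size,
        "use_gqa": num_kv_heads != config.num_attention_heads,
    }

if __name__ == "__main__":
    # A Baichuan-style config without the two optional attributes.
    config = SimpleNamespace(model_type="baichuan", num_attention_heads=32)
    print(resolve_attention_config(config, world_size=2))
    # -> {'rope_theta': 10000, 'num_heads_per_rank': 16,
    #     'num_kv_heads_per_rank': 16, 'use_gqa': False}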