Unverified commit 5303c1ed, authored by Ke Bao, committed by GitHub
Browse files

Support Mistral-Nemo (#691)

parent 65bd1338
...@@ -70,6 +70,7 @@ class LlamaMLP(nn.Module): ...@@ -70,6 +70,7 @@ class LlamaMLP(nn.Module):
class LlamaAttention(nn.Module): class LlamaAttention(nn.Module):
def __init__( def __init__(
self, self,
config: LlamaConfig,
hidden_size: int, hidden_size: int,
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
...@@ -96,7 +97,10 @@ class LlamaAttention(nn.Module): ...@@ -96,7 +97,10 @@ class LlamaAttention(nn.Module):
# the KV heads across multiple tensor parallel GPUs. # the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0 assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads # MistralConfig has an optional head_dim introduced by Mistral-Nemo
self.head_dim = getattr(
config, "head_dim", self.hidden_size // self.total_num_heads
)
self.q_size = self.num_heads * self.head_dim self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
...@@ -168,6 +172,7 @@ class LlamaDecoderLayer(nn.Module): ...@@ -168,6 +172,7 @@ class LlamaDecoderLayer(nn.Module):
rope_is_neox_style = getattr(config, "rope_is_neox_style", True) rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192) max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.self_attn = LlamaAttention( self.self_attn = LlamaAttention(
config=config,
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment