Unverified Commit 4f858475 authored by zhaoyang-star, committed by GitHub

Fix mqa is false case in gpt_bigcode (#806)

parent 65fc1c31
@@ -49,10 +49,11 @@ class GPTBigCodeAttention(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size
         total_num_heads = config.num_attention_heads
-        tensor_model_parallel_world_size = (
+        self.tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
-        assert total_num_heads % tensor_model_parallel_world_size == 0
-        self.num_heads = total_num_heads // tensor_model_parallel_world_size
+        assert total_num_heads % self.tensor_model_parallel_world_size == 0
+        self.num_heads = (total_num_heads //
+                          self.tensor_model_parallel_world_size)
         self.head_dim = self.hidden_size // total_num_heads
         self.scale = self.head_dim**-0.5
@@ -101,7 +102,10 @@ class GPTBigCodeAttention(nn.Module):
             k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1)
         else:
             qkv, _ = self.c_attn(hidden_states)
-            q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim],
-                                dim=-1)
+            q, k, v = qkv.split([
+                self.hidden_size // self.tensor_model_parallel_world_size,
+                self.kv_dim, self.kv_dim
+            ],
+                                dim=-1)
         key_cache, value_cache = kv_cache
         attn_output = self.attn(q, k, v, key_cache, value_cache,
@@ -255,8 +259,6 @@ class GPTBigCodeForCausalLM(nn.Module):
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
-        tensor_model_parallel_world_size = (
-            get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
@@ -286,7 +288,8 @@ class GPTBigCodeForCausalLM(nn.Module):
                 hidden_size = self.config.hidden_size
                 head_size = hidden_size // total_num_heads
                 total_kv_size = head_size * total_num_kv_heads
-                num_heads = total_num_heads // tensor_model_parallel_world_size
+                num_heads = (total_num_heads //
+                             self.tensor_model_parallel_world_size)
                 head_start = tensor_model_parallel_rank * num_heads
                 head_end = (tensor_model_parallel_rank + 1) * num_heads
@@ -326,7 +329,7 @@ class GPTBigCodeForCausalLM(nn.Module):
             if name == "transformer.wte.weight":
                 # Consider padding in the vocab size.
                 padded_vocab_size = param.shape[
-                    0] * tensor_model_parallel_world_size
+                    0] * self.tensor_model_parallel_world_size
                 num_extra_rows = padded_vocab_size - self.config.vocab_size
                 extra_rows = torch.empty(num_extra_rows,
                                          loaded_weight.shape[1])
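In short: when multi-query attention is disabled, the fused c_attn output is sharded across tensor-parallel ranks, so each rank only holds hidden_size // tensor_model_parallel_world_size query channels; splitting by the full hidden_size only worked on a single GPU. Below is a minimal, self-contained sketch of that split arithmetic with illustrative sizes (4096 hidden units, 32 heads, tensor-parallel world size 4); it is not the vLLM module itself.

# Sketch of the per-rank QKV split, with assumed illustrative sizes.
import torch

hidden_size = 4096                                # model width (assumed)
total_num_heads = 32                              # heads before sharding (assumed)
head_dim = hidden_size // total_num_heads         # 128
tp_world_size = 4                                 # tensor-parallel world size (assumed)

num_heads = total_num_heads // tp_world_size      # 8 heads per rank
kv_dim = num_heads * head_dim                     # per-rank key/value width when MQA is off

# Per-rank fused projection output: this rank's share of Q, K and V concatenated.
qkv = torch.randn(2, 16, hidden_size // tp_world_size + 2 * kv_dim)  # (batch, seq, 3072)

# Old split: sizes sum to 4096 + 2 * 1024 = 6144, but the last dimension is only
# 3072, so torch raises a RuntimeError whenever tp_world_size > 1.
try:
    q, k, v = qkv.split([hidden_size, kv_dim, kv_dim], dim=-1)
except RuntimeError as err:
    print("buggy split:", err)

# Fixed split: the query slice is this rank's share of hidden_size.
q, k, v = qkv.split([hidden_size // tp_world_size, kv_dim, kv_dim], dim=-1)
print(q.shape, k.shape, v.shape)                  # each (2, 16, 1024)

With tp_world_size = 1 both splits agree, which is why the bug only shows up when tensor parallelism is enabled and multi_query is false.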