Unverified commit 9dd10678, authored by pengcheng888, committed by GitHub
Browse files

Merge pull request #101 from pengcheng888/issue/89

issue/89 在python的llama中使用matmul函数、以及减少Tensor对象创建次数
parents 36f8eab7 7e59976b
...@@ -86,6 +86,7 @@ def test( ...@@ -86,6 +86,7 @@ def test(
infini_device=infinicore.device("cpu", 0), infini_device=infinicore.device("cpu", 0),
backend="python", backend="python",
): ):
model_path = os.path.expanduser(model_path)
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# 创建模型, # 创建模型,
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
...@@ -104,14 +105,12 @@ def test( ...@@ -104,14 +105,12 @@ def test(
model.load_state_dict(model_param_infini) model.load_state_dict(model_param_infini)
config = model.config
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# 创建 tokenizer # 创建 tokenizer
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
tokenizer = AutoTokenizer.from_pretrained(model_path) tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if "llama" == config.model_type: if "llama" == model.config.model_type:
backend = getattr(tokenizer, "backend_tokenizer", None) backend = getattr(tokenizer, "backend_tokenizer", None)
target = getattr(backend, "_tokenizer", backend) target = getattr(backend, "_tokenizer", backend)
norm = getattr(target, "normalizer", None) norm = getattr(target, "normalizer", None)
...@@ -129,7 +128,7 @@ def test( ...@@ -129,7 +128,7 @@ def test(
] ]
) )
else: else:
raise ValueError(f"Unsupported model type: {config.model_type}") raise ValueError(f"Unsupported model type: {model.config.model_type}")
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# token编码 # token编码
...@@ -162,7 +161,6 @@ def test( ...@@ -162,7 +161,6 @@ def test(
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
device=infini_device, device=infini_device,
tokenizer=tokenizer, tokenizer=tokenizer,
config=config,
) )
t2 = time.time() t2 = time.time()
......
...@@ -65,12 +65,12 @@ class DynamicLayer(CacheLayerMixin): ...@@ -65,12 +65,12 @@ class DynamicLayer(CacheLayerMixin):
self.max_seq_len = max(self.max_position_embeddings, seq_len) self.max_seq_len = max(self.max_position_embeddings, seq_len)
self.keys = infinicore.empty( self.keys = infinicore.empty(
[batch_size, self.max_seq_len, num_heads, head_dim], (batch_size, self.max_seq_len, num_heads, head_dim),
dtype=dtype, dtype=dtype,
device=device, device=device,
) )
self.values = infinicore.empty( self.values = infinicore.empty(
[batch_size, self.max_seq_len, num_heads, head_dim], (batch_size, self.max_seq_len, num_heads, head_dim),
dtype=dtype, dtype=dtype,
device=device, device=device,
) )
...@@ -80,12 +80,12 @@ class DynamicLayer(CacheLayerMixin): ...@@ -80,12 +80,12 @@ class DynamicLayer(CacheLayerMixin):
self.max_seq_len = max(self.max_seq_len * 2, self.cache_position + seq_len) self.max_seq_len = max(self.max_seq_len * 2, self.cache_position + seq_len)
keys_new = infinicore.empty( keys_new = infinicore.empty(
[batch_size, self.max_seq_len, num_heads, head_dim], (batch_size, self.max_seq_len, num_heads, head_dim),
dtype=dtype, dtype=dtype,
device=device, device=device,
) )
values_new = infinicore.empty( values_new = infinicore.empty(
[batch_size, self.max_seq_len, num_heads, head_dim], (batch_size, self.max_seq_len, num_heads, head_dim),
dtype=dtype, dtype=dtype,
device=device, device=device,
) )
......
...@@ -121,7 +121,6 @@ class GenerationMixin: ...@@ -121,7 +121,6 @@ class GenerationMixin:
max_new_tokens: int, max_new_tokens: int,
device: infinicore.device, device: infinicore.device,
tokenizer, tokenizer,
config,
**kwargs, **kwargs,
): ):
model_kwargs = kwargs model_kwargs = kwargs
...@@ -144,7 +143,6 @@ class GenerationMixin: ...@@ -144,7 +143,6 @@ class GenerationMixin:
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
device=device, device=device,
tokenizer=tokenizer, tokenizer=tokenizer,
config=config,
**model_kwargs, **model_kwargs,
) )
return result return result
...@@ -155,7 +153,6 @@ class GenerationMixin: ...@@ -155,7 +153,6 @@ class GenerationMixin:
max_new_tokens: int, max_new_tokens: int,
device: infinicore.device, device: infinicore.device,
tokenizer, tokenizer,
config,
**model_kwargs, **model_kwargs,
): ):
r""" r"""
...@@ -170,7 +167,7 @@ class GenerationMixin: ...@@ -170,7 +167,7 @@ class GenerationMixin:
batch_size, seq_len = input_ids.shape[:2] batch_size, seq_len = input_ids.shape[:2]
eos_token_id = config.eos_token_id eos_token_id = self.config.eos_token_id
eos_token_id_list = ( eos_token_id_list = (
[eos_token_id] if isinstance(eos_token_id, int) else eos_token_id [eos_token_id] if isinstance(eos_token_id, int) else eos_token_id
) )
...@@ -216,7 +213,7 @@ class GenerationMixin: ...@@ -216,7 +213,7 @@ class GenerationMixin:
device=token_scores.device, device=token_scores.device,
) )
for i in range(0, batch_size): for i in range(0, batch_size):
score = token_scores.narrow(0, i, 1).view([vocab_size]) score = token_scores.narrow(0, i, 1).view((vocab_size,))
out = next_tokens.narrow(0, i, 1).view([]) out = next_tokens.narrow(0, i, 1).view([])
infinicore.nn.functional.random_sample( infinicore.nn.functional.random_sample(
score, score,
...@@ -247,16 +244,16 @@ class GenerationMixin: ...@@ -247,16 +244,16 @@ class GenerationMixin:
break break
print("\n</s>") print("\n</s>")
print(f"\n\n\n Generation completed in {round(sum(time_list),2)} ms") print(f"\n\n\n Generation completed in {round(sum(time_list), 2)} ms")
print( print(
f" Batchsize={batch_size} Per_Batch_Input_Len={seq_len} Per_Batch_New_Tokens={len(time_list)}\n" f" Batchsize={batch_size} Per_Batch_Input_Len={seq_len} Per_Batch_New_Tokens={len(time_list)}\n"
) )
print( print(
f" Prefill TTFT: {round(time_list[0], 2)}ms Throughput: {round((1000 * batch_size * seq_len)/time_list[0], 2)}tok/s\n", f" Prefill TTFT: {round(time_list[0], 2)}ms Throughput: {round((1000 * batch_size * seq_len) / time_list[0], 2)}tok/s\n",
) )
if len(time_list) > 1: if len(time_list) > 1:
print( print(
f" Decode Avg ITL: {round(sum(time_list[1:]) / (len(time_list) - 1), 2)}ms Throughput: {round((1000 * batch_size * (len(time_list) - 1))/ sum(time_list[1:]), 2)}tok/s\n", f" Decode Avg ITL: {round(sum(time_list[1:]) / (len(time_list) - 1), 2)}ms Throughput: {round((1000 * batch_size * (len(time_list) - 1)) / sum(time_list[1:]), 2)}tok/s\n",
) )
return output_tokens_list, output_content return output_tokens_list, output_content
...@@ -62,13 +62,8 @@ def multi_head_attention( ...@@ -62,13 +62,8 @@ def multi_head_attention(
# [num_heads, seq_len, head_dim] @ [ num_heads, head_dim, total_seq_len] # [num_heads, seq_len, head_dim] @ [ num_heads, head_dim, total_seq_len]
# => [ num_heads, seq_len, total_seq_len] # => [ num_heads, seq_len, total_seq_len]
attn_weight = Q @ K.permute((1, 2, 0)) # Q @ K.T *scaling
attn_weight = infinicore.matmul(Q, K.permute((1, 2, 0)), alpha=scaling)
scaling = infinicore.from_list(
[scaling], dtype=attn_weight.dtype, device=attn_weight.device
).as_strided(attn_weight.shape, [0, 0, 0])
attn_weight = attn_weight * scaling
infinicore.nn.functional.causal_softmax(attn_weight, out=attn_weight) infinicore.nn.functional.causal_softmax(attn_weight, out=attn_weight)
...@@ -169,6 +164,8 @@ class LlamaAttention(infinicore.nn.Module): ...@@ -169,6 +164,8 @@ class LlamaAttention(infinicore.nn.Module):
**kwargs, **kwargs,
) )
self.attn_output = None # Variable reuse
def forward( def forward(
self, self,
hidden_states: infinicore.Tensor, hidden_states: infinicore.Tensor,
...@@ -184,7 +181,7 @@ class LlamaAttention(infinicore.nn.Module): ...@@ -184,7 +181,7 @@ class LlamaAttention(infinicore.nn.Module):
values_shape = (bs, seq_len, self.num_key_value_heads, self.head_dim) values_shape = (bs, seq_len, self.num_key_value_heads, self.head_dim)
# --------------------------------------------------------------------------------------- # # --------------------------------------------------------------------------------------- #
# 对 Q,KV进行 project # 对 Q,K,V进行 project
# --------------------------------------------------------------------------------------- # # --------------------------------------------------------------------------------------- #
# => [bs, seq_len, num_attention_heads, head_dim] # => [bs, seq_len, num_attention_heads, head_dim]
query_states = self.q_proj(hidden_states).view(querys_shape) query_states = self.q_proj(hidden_states).view(querys_shape)
...@@ -196,13 +193,9 @@ class LlamaAttention(infinicore.nn.Module): ...@@ -196,13 +193,9 @@ class LlamaAttention(infinicore.nn.Module):
value_states = self.v_proj(hidden_states).view(values_shape) value_states = self.v_proj(hidden_states).view(values_shape)
# --------------------------------------------------------------------------------------- # # --------------------------------------------------------------------------------------- #
# 对 Q和K 加上 rope # 对 Q和K 加上 rope
# --------------------------------------------------------------------------------------- # # --------------------------------------------------------------------------------------- #
position_ids = kwargs.pop("position_ids", None) position_ids = kwargs.pop("position_ids", None)
if position_ids is None:
raise KeyError("position_ids error")
if rope_instance is None:
raise KeyError("rope_instance error")
query_states = rope_instance(query_states, position_ids) query_states = rope_instance(query_states, position_ids)
key_states = rope_instance(key_states, position_ids) key_states = rope_instance(key_states, position_ids)
...@@ -223,7 +216,14 @@ class LlamaAttention(infinicore.nn.Module): ...@@ -223,7 +216,14 @@ class LlamaAttention(infinicore.nn.Module):
# 注意力计算 # 注意力计算
# --------------------------------------------------------------------------------------- # # --------------------------------------------------------------------------------------- #
total_seq_len = key_states_total.shape[1] total_seq_len = key_states_total.shape[1]
attn_output = infinicore.empty_like(query_states)
if self.attn_output is None or self.attn_output.shape[1] != seq_len:
self.attn_output = infinicore.empty(
(bs, seq_len, self.num_attention_heads, self.head_dim),
dtype=query_states.dtype,
device=query_states.device,
)
for i in range(0, bs): for i in range(0, bs):
query_states_i = query_states.narrow(0, i, 1).view( query_states_i = query_states.narrow(0, i, 1).view(
(seq_len, self.num_attention_heads, self.head_dim) (seq_len, self.num_attention_heads, self.head_dim)
...@@ -235,7 +235,7 @@ class LlamaAttention(infinicore.nn.Module): ...@@ -235,7 +235,7 @@ class LlamaAttention(infinicore.nn.Module):
(total_seq_len, self.num_key_value_heads, self.head_dim) (total_seq_len, self.num_key_value_heads, self.head_dim)
) )
attn_output_i = attn_output.narrow(0, i, 1).view( attn_output_i = self.attn_output.narrow(0, i, 1).view(
(seq_len, self.num_attention_heads, self.head_dim) (seq_len, self.num_attention_heads, self.head_dim)
) )
...@@ -249,8 +249,9 @@ class LlamaAttention(infinicore.nn.Module): ...@@ -249,8 +249,9 @@ class LlamaAttention(infinicore.nn.Module):
# out project # out project
# --------------------------------------------------------------------------------------- # # --------------------------------------------------------------------------------------- #
# ([bs, seq_len, num_attention_heads, head_dim]) ==> [bs, seq_len, hidden_size ] # ([bs, seq_len, num_attention_heads, head_dim]) ==> [bs, seq_len, hidden_size ]
attn_output = attn_output.view(hidden_states_shape) attn_output = self.attn_output.view(
(bs, seq_len, self.num_attention_heads * self.head_dim)
)
# o_proj # o_proj
return self.o_proj(attn_output) return self.o_proj(attn_output)
...@@ -292,7 +293,7 @@ class LlamaDecoderLayer(infinicore.nn.Module): ...@@ -292,7 +293,7 @@ class LlamaDecoderLayer(infinicore.nn.Module):
**kwargs, **kwargs,
) )
hidden_states = residual + hidden_states hidden_states += residual
# ------------------------------------------------ # # ------------------------------------------------ #
# Fully Connected # Fully Connected
...@@ -303,7 +304,7 @@ class LlamaDecoderLayer(infinicore.nn.Module): ...@@ -303,7 +304,7 @@ class LlamaDecoderLayer(infinicore.nn.Module):
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states hidden_states += residual
return hidden_states return hidden_states
...@@ -375,7 +376,10 @@ class LlamaModel(infinicore.nn.Module): ...@@ -375,7 +376,10 @@ class LlamaModel(infinicore.nn.Module):
# norm # norm
# --------------------------------------------------------- # # --------------------------------------------------------- #
seq_len = hidden_states.shape[1] seq_len = hidden_states.shape[1]
if seq_len > 1:
last_token = hidden_states.narrow(1, seq_len - 1, 1) last_token = hidden_states.narrow(1, seq_len - 1, 1)
else:
last_token = hidden_states
return self.norm(last_token) return self.norm(last_token)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment