Unverified Commit 9dd10678 authored by pengcheng888, committed by GitHub

Merge pull request #101 from pengcheng888/issue/89

issue/89: use the matmul function in the Python llama implementation and reduce the number of Tensor object creations
parents 36f8eab7 7e59976b
@@ -86,6 +86,7 @@ def test(
     infini_device=infinicore.device("cpu", 0),
     backend="python",
 ):
+    model_path = os.path.expanduser(model_path)
     # ---------------------------------------------------------------------------- #
     # Create the model
     # ---------------------------------------------------------------------------- #
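Note: `os.path.expanduser` only rewrites a leading `~`, so absolute paths pass through unchanged. A minimal illustration (the paths are hypothetical):

```python
import os

# "~" expands to the current user's home directory; anything else is returned as-is.
print(os.path.expanduser("~/models/llama"))  # e.g. /home/alice/models/llama
print(os.path.expanduser("/opt/models"))     # unchanged: /opt/models
```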
@@ -104,14 +105,12 @@ def test(
     model.load_state_dict(model_param_infini)
-    config = model.config

     # ---------------------------------------------------------------------------- #
     # Create the tokenizer
     # ---------------------------------------------------------------------------- #
-    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

-    if "llama" == config.model_type:
+    if "llama" == model.config.model_type:
         backend = getattr(tokenizer, "backend_tokenizer", None)
         target = getattr(backend, "_tokenizer", backend)
         norm = getattr(target, "normalizer", None)
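The `getattr` chain above defensively reaches the fast tokenizer's Rust-backed normalizer: the attribute layout varies across `transformers` versions, so each step falls back when a layer is absent. A self-contained sketch of the same probing pattern, with stand-in classes rather than a real tokenizer:

```python
class RustTokenizer:
    normalizer = "Sequence([...])"   # placeholder for the Rust-side normalizer

class FastTokenizer:
    backend_tokenizer = RustTokenizer()

tok = FastTokenizer()
backend = getattr(tok, "backend_tokenizer", None)
target = getattr(backend, "_tokenizer", backend)  # no _tokenizer: fall back to backend
norm = getattr(target, "normalizer", None)
print(norm)  # Sequence([...])
```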
@@ -129,7 +128,7 @@ def test(
             ]
         )
     else:
-        raise ValueError(f"Unsupported model type: {config.model_type}")
+        raise ValueError(f"Unsupported model type: {model.config.model_type}")

     # ---------------------------------------------------------------------------- #
     # Token encoding
@@ -162,7 +161,6 @@ def test(
         max_new_tokens=max_new_tokens,
         device=infini_device,
         tokenizer=tokenizer,
-        config=config,
     )
     t2 = time.time()
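Dropping the `config=config` argument works because the model already owns its configuration, so `generate` can read `self.config` directly. A minimal sketch of the pattern with placeholder classes (all names here are illustrative):

```python
class Config:
    eos_token_id = 2

class Model:
    """Sketch: the model carries its config, so generate() needs no config arg."""
    config = Config()

    def generate(self, token_ids):
        eos = self.config.eos_token_id  # read from self instead of a parameter
        return token_ids + [eos]        # placeholder "generation"

print(Model().generate([1, 5, 7]))  # [1, 5, 7, 2]
```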
@@ -65,12 +65,12 @@ class DynamicLayer(CacheLayerMixin):
         self.max_seq_len = max(self.max_position_embeddings, seq_len)
         self.keys = infinicore.empty(
-            [batch_size, self.max_seq_len, num_heads, head_dim],
+            (batch_size, self.max_seq_len, num_heads, head_dim),
             dtype=dtype,
             device=device,
         )
         self.values = infinicore.empty(
-            [batch_size, self.max_seq_len, num_heads, head_dim],
+            (batch_size, self.max_seq_len, num_heads, head_dim),
             dtype=dtype,
             device=device,
         )
@@ -80,12 +80,12 @@ class DynamicLayer(CacheLayerMixin):
         self.max_seq_len = max(self.max_seq_len * 2, self.cache_position + seq_len)
         keys_new = infinicore.empty(
-            [batch_size, self.max_seq_len, num_heads, head_dim],
+            (batch_size, self.max_seq_len, num_heads, head_dim),
             dtype=dtype,
             device=device,
         )
         values_new = infinicore.empty(
-            [batch_size, self.max_seq_len, num_heads, head_dim],
+            (batch_size, self.max_seq_len, num_heads, head_dim),
             dtype=dtype,
             device=device,
         )
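Doubling `max_seq_len` when the cache overflows is the classic dynamic-array strategy: each appended token is amortized O(1) because the one-off copy cost is spread over the doubled capacity. A sketch of the growth step, with NumPy standing in for infinicore tensors (keys only; values follow the same path):

```python
import numpy as np

class GrowableKVCache:
    """Sketch of the doubling strategy used by DynamicLayer."""

    def __init__(self, batch, heads, head_dim, initial_len):
        self.max_seq_len = initial_len
        self.cache_position = 0
        self.keys = np.empty((batch, initial_len, heads, head_dim), np.float16)

    def ensure_capacity(self, seq_len):
        if self.cache_position + seq_len <= self.max_seq_len:
            return  # fits in the current buffer, no copy needed
        # Double, or jump straight to the required length, then copy once.
        self.max_seq_len = max(self.max_seq_len * 2, self.cache_position + seq_len)
        keys_new = np.empty(
            (self.keys.shape[0], self.max_seq_len, *self.keys.shape[2:]),
            dtype=self.keys.dtype,
        )
        keys_new[:, : self.cache_position] = self.keys[:, : self.cache_position]
        self.keys = keys_new

cache = GrowableKVCache(batch=1, heads=8, head_dim=64, initial_len=16)
cache.ensure_capacity(seq_len=40)
print(cache.keys.shape)  # (1, 40, 8, 64)
```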
@@ -121,7 +121,6 @@ class GenerationMixin:
         max_new_tokens: int,
         device: infinicore.device,
         tokenizer,
-        config,
         **kwargs,
     ):
         model_kwargs = kwargs
@@ -144,7 +143,6 @@ class GenerationMixin:
             max_new_tokens=max_new_tokens,
             device=device,
             tokenizer=tokenizer,
-            config=config,
             **model_kwargs,
         )
         return result
@@ -155,7 +153,6 @@ class GenerationMixin:
         max_new_tokens: int,
         device: infinicore.device,
         tokenizer,
-        config,
         **model_kwargs,
     ):
         r"""
@@ -170,7 +167,7 @@ class GenerationMixin:
         batch_size, seq_len = input_ids.shape[:2]

-        eos_token_id = config.eos_token_id
+        eos_token_id = self.config.eos_token_id
         eos_token_id_list = (
             [eos_token_id] if isinstance(eos_token_id, int) else eos_token_id
         )
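Configs store `eos_token_id` either as a single int or as a list of ids, so normalizing to a list keeps the downstream stop check a plain membership test:

```python
def normalize_eos(eos_token_id):
    # Accept both a single id and a list of ids.
    return [eos_token_id] if isinstance(eos_token_id, int) else eos_token_id

print(normalize_eos(2))            # [2]
print(normalize_eos([2, 128001]))  # [2, 128001]
```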
@@ -216,7 +213,7 @@ class GenerationMixin:
                 device=token_scores.device,
             )
             for i in range(0, batch_size):
-                score = token_scores.narrow(0, i, 1).view([vocab_size])
+                score = token_scores.narrow(0, i, 1).view((vocab_size,))
                 out = next_tokens.narrow(0, i, 1).view([])
                 infinicore.nn.functional.random_sample(
                     score,
@@ -247,16 +244,16 @@ class GenerationMixin:
                 break
         print("\n</s>")

-        print(f"\n\n\n Generation completed in {round(sum(time_list),2)} ms")
+        print(f"\n\n\n Generation completed in {round(sum(time_list), 2)} ms")
         print(
             f" Batchsize={batch_size} Per_Batch_Input_Len={seq_len} Per_Batch_New_Tokens={len(time_list)}\n"
         )
         print(
-            f" Prefill TTFT: {round(time_list[0], 2)}ms Throughput: {round((1000 * batch_size * seq_len)/time_list[0], 2)}tok/s\n",
+            f" Prefill TTFT: {round(time_list[0], 2)}ms Throughput: {round((1000 * batch_size * seq_len) / time_list[0], 2)}tok/s\n",
         )
         if len(time_list) > 1:
             print(
-                f" Decode Avg ITL: {round(sum(time_list[1:]) / (len(time_list) - 1), 2)}ms Throughput: {round((1000 * batch_size * (len(time_list) - 1))/ sum(time_list[1:]), 2)}tok/s\n",
+                f" Decode Avg ITL: {round(sum(time_list[1:]) / (len(time_list) - 1), 2)}ms Throughput: {round((1000 * batch_size * (len(time_list) - 1)) / sum(time_list[1:]), 2)}tok/s\n",
             )
         return output_tokens_list, output_content
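The printed metrics follow the usual definitions: TTFT is the first (prefill) step's latency, prefill throughput divides the prompt tokens by that time, and decode ITL averages the remaining per-token steps. A worked example with hypothetical timings:

```python
# Hypothetical timings in ms: step 0 is prefill, the rest are decode steps.
time_list = [120.0, 15.0, 16.0, 14.0]
batch_size, seq_len = 2, 32

ttft = time_list[0]                                    # time to first token (ms)
prefill_tps = 1000 * batch_size * seq_len / ttft       # prompt tokens per second
itl = sum(time_list[1:]) / (len(time_list) - 1)        # mean inter-token latency (ms)
decode_tps = 1000 * batch_size * (len(time_list) - 1) / sum(time_list[1:])

print(f"TTFT {ttft:.2f}ms  prefill {prefill_tps:.2f} tok/s")  # 120.00ms, 533.33 tok/s
print(f"ITL {itl:.2f}ms  decode {decode_tps:.2f} tok/s")      # 15.00ms, 133.33 tok/s
```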
@@ -62,13 +62,8 @@ def multi_head_attention(
     # [num_heads, seq_len, head_dim] @ [num_heads, head_dim, total_seq_len]
     # => [num_heads, seq_len, total_seq_len]
-    attn_weight = Q @ K.permute((1, 2, 0))
-
-    scaling = infinicore.from_list(
-        [scaling], dtype=attn_weight.dtype, device=attn_weight.device
-    ).as_strided(attn_weight.shape, [0, 0, 0])
-
-    attn_weight = attn_weight * scaling
+    # Q @ K.T * scaling
+    attn_weight = infinicore.matmul(Q, K.permute((1, 2, 0)), alpha=scaling)

     infinicore.nn.functional.causal_softmax(attn_weight, out=attn_weight)
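Folding the softmax scale into the GEMM (the `alpha` scalar, as in BLAS `gemm`) removes both the zero-strided broadcast tensor and the separate elementwise multiply. NumPy has no `alpha` argument, but a sketch can verify the arithmetic the fusion relies on:

```python
import numpy as np

rng = np.random.default_rng(0)
Q = rng.standard_normal((8, 5, 64))   # [num_heads, seq_len, head_dim]
K = rng.standard_normal((8, 7, 64))   # [num_heads, total_seq_len, head_dim]
scaling = 1.0 / np.sqrt(64)

# Old path: matmul, then a second full pass for the elementwise scale.
two_pass = (Q @ K.transpose(0, 2, 1)) * scaling

# What alpha fuses: the GEMM itself produces alpha * (Q @ K^T) in one pass.
fused_equivalent = scaling * np.matmul(Q, K.transpose(0, 2, 1))

print(np.allclose(two_pass, fused_equivalent))  # True
```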
@@ -169,6 +164,8 @@ class LlamaAttention(infinicore.nn.Module):
             **kwargs,
         )

+        self.attn_output = None  # buffer reused across forward() calls
+
     def forward(
         self,
         hidden_states: infinicore.Tensor,
@@ -184,7 +181,7 @@ class LlamaAttention(infinicore.nn.Module):
         values_shape = (bs, seq_len, self.num_key_value_heads, self.head_dim)

         # --------------------------------------------------------------------------------------- #
-        # Project Q and KV
+        # Project Q, K and V
         # --------------------------------------------------------------------------------------- #
         # => [bs, seq_len, num_attention_heads, head_dim]
         query_states = self.q_proj(hidden_states).view(querys_shape)
@@ -196,13 +193,9 @@ class LlamaAttention(infinicore.nn.Module):
         value_states = self.v_proj(hidden_states).view(values_shape)

         # --------------------------------------------------------------------------------------- #
         # Apply RoPE to Q and K
         # --------------------------------------------------------------------------------------- #
         position_ids = kwargs.pop("position_ids", None)
-        if position_ids is None:
-            raise KeyError("position_ids error")
-        if rope_instance is None:
-            raise KeyError("rope_instance error")
         query_states = rope_instance(query_states, position_ids)
         key_states = rope_instance(key_states, position_ids)
@@ -223,7 +216,14 @@ class LlamaAttention(infinicore.nn.Module):
         # Attention computation
         # --------------------------------------------------------------------------------------- #
         total_seq_len = key_states_total.shape[1]
-        attn_output = infinicore.empty_like(query_states)
+        if self.attn_output is None or self.attn_output.shape[1] != seq_len:
+            self.attn_output = infinicore.empty(
+                (bs, seq_len, self.num_attention_heads, self.head_dim),
+                dtype=query_states.dtype,
+                device=query_states.device,
+            )
+
         for i in range(0, bs):
             query_states_i = query_states.narrow(0, i, 1).view(
                 (seq_len, self.num_attention_heads, self.head_dim)
@@ -235,7 +235,7 @@ class LlamaAttention(infinicore.nn.Module):
                 (total_seq_len, self.num_key_value_heads, self.head_dim)
             )
-            attn_output_i = attn_output.narrow(0, i, 1).view(
+            attn_output_i = self.attn_output.narrow(0, i, 1).view(
                 (seq_len, self.num_attention_heads, self.head_dim)
             )
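Caching `self.attn_output` pays off during decode, where `seq_len` stays 1: the guard on `shape[1]` reallocates only at the prefill/decode transition, so every later step reuses the same buffer instead of allocating a tensor per generated token. A NumPy sketch of the guard (the reuse assumes no downstream consumer still holds the previous step's output):

```python
import numpy as np

class AttentionSketch:
    """Sketch of the reuse guard: reallocate only when seq_len changes."""

    def __init__(self):
        self.attn_output = None

    def output_buffer(self, bs, seq_len, heads, head_dim):
        if self.attn_output is None or self.attn_output.shape[1] != seq_len:
            # First call, or a prefill -> decode transition: allocate fresh.
            self.attn_output = np.empty((bs, seq_len, heads, head_dim), np.float16)
        return self.attn_output

attn = AttentionSketch()
prefill = attn.output_buffer(1, 8, 32, 128)  # prefill: seq_len == 8
step1 = attn.output_buffer(1, 1, 32, 128)    # first decode step: reallocates
step2 = attn.output_buffer(1, 1, 32, 128)    # later steps: same buffer
print(step1 is step2)  # True
```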
@@ -249,8 +249,9 @@ class LlamaAttention(infinicore.nn.Module):
         # Output projection
         # --------------------------------------------------------------------------------------- #
         # [bs, seq_len, num_attention_heads, head_dim] => [bs, seq_len, hidden_size]
-        attn_output = attn_output.view(hidden_states_shape)
+        attn_output = self.attn_output.view(
+            (bs, seq_len, self.num_attention_heads * self.head_dim)
+        )

         # o_proj
         return self.o_proj(attn_output)
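The final `view` merges the head dimensions back into `hidden_size` (`num_attention_heads * head_dim`) so `o_proj` sees a 3-D activation; a NumPy equivalent with illustrative sizes:

```python
import numpy as np

# [bs, seq_len, num_heads, head_dim] -> [bs, seq_len, num_heads * head_dim]
x = np.zeros((2, 5, 32, 128))
merged = x.reshape(2, 5, 32 * 128)  # a view when the memory is contiguous
print(merged.shape)  # (2, 5, 4096)
```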
@@ -292,7 +293,7 @@ class LlamaDecoderLayer(infinicore.nn.Module):
             **kwargs,
         )
-        hidden_states = residual + hidden_states
+        hidden_states += residual

         # ------------------------------------------------ #
         # Fully Connected
@@ -303,7 +304,7 @@ class LlamaDecoderLayer(infinicore.nn.Module):
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
+        hidden_states += residual

         return hidden_states
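`hidden_states += residual` dispatches to an in-place add, writing into the existing buffer instead of allocating a new tensor for every residual connection; this is safe here because the pre-attention value was already saved in `residual`. NumPy shows the aliasing difference:

```python
import numpy as np

hidden_states = np.ones((2, 4), dtype=np.float32)
residual = np.full((2, 4), 0.5, dtype=np.float32)

buf = hidden_states
hidden_states += residual      # in-place: same buffer, no new allocation
print(hidden_states is buf)    # True

hidden_states = hidden_states + residual  # out-of-place: allocates a new array
print(hidden_states is buf)    # False
```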
@@ -375,7 +376,10 @@ class LlamaModel(infinicore.nn.Module):
         # norm
         # --------------------------------------------------------- #
         seq_len = hidden_states.shape[1]
-        last_token = hidden_states.narrow(1, seq_len - 1, 1)
+        if seq_len > 1:
+            last_token = hidden_states.narrow(1, seq_len - 1, 1)
+        else:
+            last_token = hidden_states

         return self.norm(last_token)
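Only the last position's hidden state is needed to predict the next token, so narrowing before the final norm avoids normalizing (and later projecting) every prompt position; when `seq_len` is already 1 the branch skips the redundant `narrow`. A NumPy equivalent:

```python
import numpy as np

hidden_states = np.random.rand(2, 10, 4096)  # [bs, seq_len, hidden]
seq_len = hidden_states.shape[1]

# Keep only the final position (shape [bs, 1, hidden]) before the norm/LM head.
last = hidden_states[:, -1:, :] if seq_len > 1 else hidden_states
print(last.shape)  # (2, 1, 4096)
```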