Commit 23c1ad07 authored by zhaoying1's avatar zhaoying1
Browse files

Update modeling_chatglm.py

parent 43fc60f5
...@@ -277,7 +277,7 @@ def attention_fn( ...@@ -277,7 +277,7 @@ def attention_fn(
q, k, v = [rearrange(x, 'b s h d -> b h s d') for x in [query_layer, key_layer, value_layer]] q, k, v = [rearrange(x, 'b s h d -> b h s d') for x in [query_layer, key_layer, value_layer]]
ctx_lens1 = ctx_lens.to(q.device) ctx_lens1 = ctx_lens.to(q.device)
output = flash_attn_func(q, k, v,query_key_layer_scaling_coeff,ctx_lens1) output = flash_attn_func(q, k, v,query_key_layer_scaling_coeff,ctx_lens1,3)
context_layer = output.permute(2, 0, 1, 3) context_layer = output.permute(2, 0, 1, 3)
new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,) new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape) context_layer = context_layer.view(*new_context_layer_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment