Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
dad00c42
Commit
dad00c42
authored
Jul 14, 2023
by
klhhhhh
Committed by
Hongxin Liu
Aug 15, 2023
Browse files
[shardformer] support chatglm without layernorm
parent
cbb54d32
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
6 deletions
+13
-6
tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py
...it/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py
+13
-6
No files found.
tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py
View file @
dad00c42
...
...
@@ -396,17 +396,18 @@ class SelfAttention(torch.nn.Module):
self
.
num_multi_query_groups_per_partition
=
config
.
multi_query_group_num
self
.
qkv_hidden_size
=
(
self
.
projection_size
+
2
*
self
.
hidden_size_per_attention_head
*
config
.
multi_query_group_num
)
<<<<<<<
HEAD
self
.
query_key_value
=
nn
.
Linear
(
config
.
hidden_size
,
self
.
qkv_hidden_size
,
bias
=
config
.
add_bias_linear
or
config
.
add_qkv_bias
,
device
=
device
,
**
_config_to_kwargs
(
config
),
)
self
.
core_attention
=
CoreAttention
(
config
,
self
.
layer_number
)
# Output.
=======
self
.
query_key_value
=
nn
.
Linear
(
self
.
hidden_size
,
self
.
qkv_hidden_size
,
bias
=
config
.
add_bias_linear
or
config
.
add_qkv_bias
,
<<<<<<<
HEAD
self
.
dense
=
nn
.
Linear
(
self
.
projection_size
,
config
.
hidden_size
,
...
...
@@ -414,6 +415,13 @@ class SelfAttention(torch.nn.Module):
device
=
device
,
**
_config_to_kwargs
(
config
),
)
=======
self
.
dense
=
nn
.
Linear
(
self
.
projection_size
,
self
.
hidden_size
,
bias
=
config
.
add_bias_linear
,
device
=
device
,
**
_config_to_kwargs
(
config
))
>>>>>>>
[
shardformer
]
support
chatglm
without
layernorm
def
_allocate_memory
(
self
,
inference_max_sequence_len
,
batch_size
,
device
=
None
,
dtype
=
None
):
if
self
.
multi_query_attention
:
...
...
@@ -925,7 +933,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embedding
(
input_ids
)
print
(
inputs_embeds
)
if
self
.
pre_seq_len
is
not
None
:
if
past_key_values
is
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment