OpenDAS / ColossalAI · Commits

Commit cbb54d32
[shardformer] polish code

authored Jul 13, 2023 by klhhhhh, committed by Hongxin Liu on Aug 15, 2023
parent 1a29e8fc
Showing 1 changed file with 6 additions and 0 deletions
tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py  (+6 / -0)  View file @ cbb54d32
@@ -80,6 +80,7 @@ def default_init(cls, *args, **kwargs):
 class InvalidScoreLogitsProcessor(LogitsProcessor):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         if torch.isnan(scores).any() or torch.isinf(scores).any():
             scores.zero_()
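For context rather than as part of the commit: InvalidScoreLogitsProcessor guards generation against NaN/Inf logits by zeroing the score tensor in place. A minimal, self-contained sketch of the same guard, using only torch and a made-up score tensor:

import torch

def sanitize_scores(scores: torch.FloatTensor) -> torch.FloatTensor:
    # Same guard as the hunk above: if any logit is NaN or Inf,
    # wipe the whole score tensor so sampling cannot use the bad values.
    if torch.isnan(scores).any() or torch.isinf(scores).any():
        scores.zero_()
    return scores

scores = torch.tensor([[0.5, float("nan"), 2.0]])
print(sanitize_scores(scores))  # tensor([[0., 0., 0.]])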
@@ -219,6 +220,7 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Ten
 class RMSNorm(torch.nn.Module):
     def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
         super().__init__()
         self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
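The hunk shows only the RMSNorm constructor. For reference, the standard RMSNorm forward pass looks like the sketch below; it is the textbook formulation, not necessarily line-for-line what modeling_chatglm.py does, and it initializes the weight with torch.ones purely so the sketch runs on its own:

import torch

class SimpleRMSNorm(torch.nn.Module):
    # Generic RMSNorm: scale each hidden vector by the reciprocal of its
    # root mean square, then apply a learned per-channel gain.
    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(normalized_shape))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)

norm = SimpleRMSNorm(8)
print(norm(torch.randn(2, 8)).shape)  # torch.Size([2, 8])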
@@ -233,6 +235,7 @@ class RMSNorm(torch.nn.Module):
 class CoreAttention(torch.nn.Module):
     def __init__(self, config: ChatGLMConfig, layer_number):
         super(CoreAttention, self).__init__()
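Only the CoreAttention constructor is visible here. As background, the computation such a module wraps is scaled dot-product attention; the sketch below is the generic form, not the ChatGLM2-specific multi-query or memory-efficient paths:

import math
import torch

def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v: [batch, heads, seq_len, head_dim]
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

q = k = v = torch.randn(1, 2, 4, 8)
print(scaled_dot_product_attention(q, k, v).shape)  # torch.Size([1, 2, 4, 8])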
@@ -839,6 +842,7 @@ class Embedding(torch.nn.Module):
 class ChatGLMModel(ChatGLMPreTrainedModel):
     def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
         super().__init__(config)
         if empty_init:
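The empty_init flag, together with the default_init helper named in the first hunk header, follows the usual ChatGLM pattern of choosing between initializing weights immediately and allocating them uninitialized (to be overwritten by checkpoint weights). The sketch below shows that pattern as an assumption about this file, with an illustrative Linear layer:

import torch
from torch.nn.utils import skip_init

def default_init(cls, *args, **kwargs):
    # Plain construction: parameters are allocated and initialized right away.
    return cls(*args, **kwargs)

empty_init = True
# skip_init allocates parameters without running their init functions,
# which saves time when real weights will be loaded from a checkpoint.
init_method = skip_init if empty_init else default_init

layer = init_method(torch.nn.Linear, 4, 4)
print(layer.weight.shape)  # torch.Size([4, 4])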
@@ -921,6 +925,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embedding(input_ids)
+            print(inputs_embeds)
         if self.pre_seq_len is not None:
             if past_key_values is None:
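The single added line in this hunk prints the freshly computed input embeddings, presumably a debug statement. The surrounding pattern, falling back to an embedding lookup when no precomputed inputs_embeds are passed, is sketched below with made-up vocabulary and hidden sizes:

import torch

embedding = torch.nn.Embedding(num_embeddings=100, embedding_dim=16)  # hypothetical sizes
input_ids = torch.tensor([[1, 2, 3]])

inputs_embeds = None
if inputs_embeds is None:
    # Same fallback as the hunk: derive embeddings from the token ids.
    inputs_embeds = embedding(input_ids)
    print(inputs_embeds.shape)  # torch.Size([1, 3, 16])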
@@ -982,6 +987,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
     def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
         super().__init__(config)
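ChatGLMForConditionalGeneration is the generation wrapper around ChatGLMModel that this test-kit file mirrors. A usage sketch, not part of this commit, assuming the public THUDM/chatglm2-6b checkpoint and the trust_remote_code loading path; the prompt and generation settings are illustrative:

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = model.half().cuda().eval() if torch.cuda.is_available() else model.eval()

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))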