OpenDAS / text-generation-inference / Commits

Commit 3b41e93a (Unverified)
Hotfix: fix MPT after recent refactor (#2257)
Authored Jul 19, 2024 by Daniël de Kok; committed by GitHub, Jul 19, 2024
Parent: 18db78f2
Changes: 2 changed files with 26 additions and 21 deletions (+26 -21)

  server/text_generation_server/models/causal_lm.py                       +7  -2
  server/text_generation_server/models/custom_modeling/mpt_modeling.py    +19 -19
server/text_generation_server/models/causal_lm.py

@@ -492,7 +492,7 @@ class CausalLMBatch(Batch):
 @dataclass
-class CausalLMBatchKeysLast(Batch):
+class CausalLMBatchKeysLast(CausalLMBatch):
     keys_head_dim_last: bool = False
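Note: `CausalLMBatchKeysLast` is a dataclass that only overrides the `keys_head_dim_last` flag, so the change above makes it inherit from the concrete `CausalLMBatch` (which carries the batch fields and methods) rather than from the abstract `Batch` base, where those would be missing. A minimal, self-contained sketch of that pattern; the `ToyBatch` names below are hypothetical and not from the repository:

from dataclasses import dataclass

@dataclass
class ToyBatch:
    input_ids: list
    keys_head_dim_last: bool = True

    def max_tokens(self) -> int:
        return len(self.input_ids)

@dataclass
class ToyBatchKeysLast(ToyBatch):
    # Only the default flips; all other fields and methods come from ToyBatch.
    keys_head_dim_last: bool = False

batch = ToyBatchKeysLast(input_ids=[1, 2, 3])
assert batch.keys_head_dim_last is False and batch.max_tokens() == 3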
@@ -544,7 +544,12 @@ class CausalLM(Model):
         config.quantize = quantize
         config.speculator = speculator
         if tokenizer.pad_token_id is None:
-            tokenizer.pad_token_id = config.pad_token_id
+            if config.pad_token_id is not None:
+                tokenizer.pad_token_id = config.pad_token_id
+            elif config.eos_token_id is not None:
+                tokenizer.pad_token_id = config.eos_token_id
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id

         torch.distributed.barrier(group=self.process_group)
         weights_loader = get_loader(
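The old code assigned `config.pad_token_id` unconditionally, which leaves the pad token broken for models whose config does not define one. A minimal sketch of the new fallback chain, using stand-in objects (`SimpleNamespace` here stands in for the real tokenizer and config and is illustrative only):

from types import SimpleNamespace

def resolve_pad_token_id(tokenizer, config):
    # Mirrors the hunk above: prefer config.pad_token_id, then config.eos_token_id,
    # then the tokenizer's own eos_token_id; otherwise leave pad_token_id unset.
    if tokenizer.pad_token_id is None:
        if config.pad_token_id is not None:
            tokenizer.pad_token_id = config.pad_token_id
        elif config.eos_token_id is not None:
            tokenizer.pad_token_id = config.eos_token_id
        elif tokenizer.eos_token_id is not None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer.pad_token_id

# Config defines neither a pad nor an eos token; the tokenizer's eos id is used.
tok = SimpleNamespace(pad_token_id=None, eos_token_id=2)
cfg = SimpleNamespace(pad_token_id=None, eos_token_id=None)
assert resolve_pad_token_id(tok, cfg) == 2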
server/text_generation_server/models/custom_modeling/mpt_modeling.py

@@ -337,17 +337,17 @@ class MultiheadAttention(nn.Module):
         weights,
     ):
         super().__init__()
-        attn_impl = config.attn_config["attn_impl"]
-        self.attn_impl = config.attn_config["attn_impl"]
-        self.clip_qkv = config.attn_config["clip_qkv"]
-        self.qk_ln = config.attn_config["qk_ln"]
+        attn_impl = config.attn_config.attn_impl
+        self.attn_impl = config.attn_config.attn_impl
+        self.clip_qkv = config.attn_config.clip_qkv
+        self.qk_ln = config.attn_config.qk_ln
         self.d_model = config.d_model
         d_model = config.d_model
         self.n_heads = config.n_heads
-        self.softmax_scale = config.attn_config["softmax_scale"]
+        self.softmax_scale = config.attn_config.softmax_scale
         if self.softmax_scale is None:
             self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
-        self.attn_dropout_p = config.attn_config["attn_pdrop"]
+        self.attn_dropout_p = config.attn_config.attn_pdrop
         if self.n_heads % weights.process_group.size() != 0:
             raise ValueError(
@@ -430,17 +430,17 @@ class MultiQueryAttention(nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
-        attn_impl = config.attn_config["attn_impl"]
-        self.attn_impl = config.attn_config["attn_impl"]
-        self.clip_qkv = config.attn_config["clip_qkv"]
-        self.qk_ln = config.attn_config["qk_ln"]
+        attn_impl = config.attn_config.attn_impl
+        self.attn_impl = config.attn_config.attn_impl
+        self.clip_qkv = config.attn_config.clip_qkv
+        self.qk_ln = config.attn_config.qk_ln
         self.d_model = config.d_model
         d_model = config.d_model
         self.n_heads = config.n_heads
-        self.softmax_scale = config.attn_config["softmax_scale"]
+        self.softmax_scale = config.attn_config.softmax_scale
         if self.softmax_scale is None:
             self.softmax_scale = 1 / math.sqrt(self.head_dim)
-        self.attn_dropout_p = config.attn_config["attn_pdrop"]
+        self.attn_dropout_p = config.attn_config.attn_pdrop
         # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
         self.Wqkv = TensorParallelColumnLinear.load(
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
@@ -614,9 +614,9 @@ class MPTBlock(nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
         self.prefix = prefix
-        if config.attn_config["attn_type"] != "multihead_attention":
+        if config.attn_config.attn_type != "multihead_attention":
             raise NotImplementedError(
-                f"""Not implemented attn {config.attn_config["attn_type"]}"""
+                f"""Not implemented attn {config.attn_config.attn_type}"""
             )
         resid_pdrop = config.resid_pdrop
         if config.no_bias:
@@ -789,11 +789,11 @@ class MPTModel(MPTPreTrainedModel):
         self.world_size = weights.process_group.size()
         self.rank = weights.process_group.rank()
         self.n_heads = config.n_heads
-        self.attn_impl = config.attn_config["attn_impl"]
-        self.prefix_lm = config.attn_config["prefix_lm"]
-        self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"]
-        self.alibi = config.attn_config["alibi"]
-        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
+        self.attn_impl = config.attn_config.attn_impl
+        self.prefix_lm = config.attn_config.prefix_lm
+        self.attn_uses_sequence_id = config.attn_config.attn_uses_sequence_id
+        self.alibi = config.attn_config.alibi
+        self.alibi_bias_max = config.attn_config.alibi_bias_max
         if config.init_device == "mixed":
             if dist.get_local_rank() == 0:
                 config.init_device = "cpu"
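All of the mpt_modeling.py changes follow one pattern: `config.attn_config` is read with attribute access instead of dictionary subscripting. The commit title points at a recent refactor, and the shape of the diff suggests `attn_config` is now an attribute-style config object rather than a plain dict, so the old `config.attn_config["attn_impl"]` form would raise a `TypeError`. A toy sketch of that difference; the `Toy*` class names are hypothetical, not the repository's types:

class ToyAttnConfig:
    # Hypothetical stand-in for the post-refactor attention sub-config.
    def __init__(self):
        self.attn_impl = "torch"
        self.clip_qkv = None
        self.qk_ln = False
        self.softmax_scale = None
        self.attn_pdrop = 0.0

class ToyConfig:
    def __init__(self):
        self.attn_config = ToyAttnConfig()

config = ToyConfig()
impl = config.attn_config.attn_impl      # attribute access: works after the refactor
try:
    config.attn_config["attn_impl"]      # old dict-style access
except TypeError as exc:
    print(f"dict-style access no longer works: {exc}")

If both shapes ever had to be supported, a getattr-based compatibility shim could bridge them, but the diff simply standardizes on attribute access throughout the file.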