OpenDAS / text-generation-inference · commit f04255c6 (unverified)
Authored Mar 29, 2024 by OlivierDehaene; committed by GitHub on Mar 29, 2024
feat: Add dbrx support (#1685)
Close #1679
Parent: 762dbf3f
Showing 4 changed files with 1180 additions and 0 deletions (+1180, -0)
server/text_generation_server/models/__init__.py (+24, -0)
server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py (+1055, -0)
server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py (+2, -0)
server/text_generation_server/models/flash_dbrx.py (+99, -0)
server/text_generation_server/models/__init__.py

@@ -71,6 +71,7 @@ try:
     from text_generation_server.models.flash_mixtral import FlashMixtral
     from text_generation_server.models.flash_phi import FlashPhi
     from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
+    from text_generation_server.models.flash_dbrx import FlashDbrx
     from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
 except ImportError as e:

@@ -86,6 +87,7 @@ if FLASH_ATTENTION:
     __all__.append(IDEFICSSharded)
     __all__.append(FlashMistral)
     __all__.append(FlashMixtral)
+    __all__.append(FlashDbrx)
     __all__.append(FlashPhi)
     __all__.append(FlashQwen2)
     __all__.append(FlashStarcoder2)

@@ -381,6 +383,28 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

+    if model_type == "dbrx":
+        if FLASH_ATTENTION:
+            return FlashDbrx(
+                model_id,
+                revision,
+                quantize=quantize,
+                use_medusa=use_medusa,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        elif sharded:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
+        else:
+            return CausalLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                use_medusa=use_medusa,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+
     if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
         if sharded:
             if FLASH_ATTENTION:
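For context, the new branch keys off the checkpoint config's model_type field: a DBRX config reports "dbrx", the flash path is taken only when flash attention is available, sharded serving without it raises, and the unsharded fallback is CausalLM. The snippet below is a hedged sketch, not part of the commit; the checkpoint id "databricks/dbrx-instruct" is an assumed example.

# Hedged sketch: shows only which branch of get_model a DBRX checkpoint would hit.
from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "databricks/dbrx-instruct",  # assumed example DBRX checkpoint id
    trust_remote_code=True,
)
assert config.model_type == "dbrx"

# get_model(...) then routes exactly as the diff above shows:
#   FLASH_ATTENTION available  -> FlashDbrx(model_id, revision, ...)
#   sharded without flash      -> NotImplementedError("Sharded DBRX" needs flash attention)
#   otherwise                  -> CausalLM(model_id, revision, ...)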
server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py (new file, mode 100644)

This diff is collapsed (+1055 lines).
server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py

@@ -552,6 +552,7 @@ class BlockSparseMoE(nn.Module):
         # Re-normalize
         weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+        weights = weights.to(x.dtype)

         # Expand to [num_experts, sequence_length, model_dim]
         x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])

@@ -660,6 +661,7 @@ class DenseMoE(nn.Module):
         # Re-normalize
         weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+        weights = weights.to(x.dtype)

         # Final output tensor
         out = x.new_zeros(x.shape[0], self.hidden_dim)
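The two one-line additions cast the re-normalized routing weights back to the activation dtype before they weight the expert outputs. Router probabilities are commonly computed in float32 for numerical stability, so without the cast the MoE mixing step would combine float32 weights with bf16/fp16 activations. The standalone illustration below is a sketch; the shapes, dtypes, and the float32 softmax are assumptions for demonstration, not values taken from the model.

import torch

# Hypothetical shapes: 4 tokens routed over 8 experts, activations in bf16.
x = torch.randn(4, 16, dtype=torch.bfloat16)
gate_logits = torch.randn(4, 8, dtype=torch.bfloat16)

# Assume the router softmax is evaluated in float32 for stability, so the
# normalized weights come out as float32 even though x is bf16.
all_probs = torch.softmax(gate_logits, dim=-1, dtype=torch.float32)
weights = all_probs / all_probs.sum(dim=1, keepdim=True)
assert weights.dtype == torch.float32

# The added line brings the routing weights back to the activation dtype
# before they are multiplied into the expert outputs.
weights = weights.to(x.dtype)
assert weights.dtype == x.dtype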
server/text_generation_server/models/flash_dbrx.py (new file, mode 100644)
import torch
import torch.distributed

from opentelemetry import trace
from typing import Optional
from transformers import AutoTokenizer
from transformers.models.gpt2 import GPT2TokenizerFast

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
    FlashDbrxForCausalLM,
    DbrxConfig,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)

tracer = trace.get_tracer(__name__)


class FlashDbrx(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        use_medusa: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.bfloat16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashDBRX is only available on GPU")

        try:
            tokenizer = GPT2TokenizerFast.from_pretrained(
                model_id,
                revision=revision,
                padding_side="left",
                truncation_side="left",
                trust_remote_code=trust_remote_code,
                use_fast=True,
                from_slow=False,
            )
        except:
            try:
                tokenizer = AutoTokenizer.from_pretrained(
                    model_id,
                    revision=revision,
                    padding_side="left",
                    truncation_side="left",
                    trust_remote_code=trust_remote_code,
                    use_fast=True,
                    from_slow=False,
                )
            except:
                # FIXME: change back to model id once the tokenizer.json is merged
                tokenizer = GPT2TokenizerFast.from_pretrained(
                    "Xenova/dbrx-instruct-tokenizer",
                    revision=revision,
                    padding_side="left",
                    truncation_side="left",
                    trust_remote_code=trust_remote_code,
                    use_fast=True,
                    from_slow=False,
                )

        config = DbrxConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.use_medusa = use_medusa

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq"]:
            weights._set_gptq_params(model_id, revision)

        model = FlashDbrxForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(FlashDbrx, self).__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )
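In practice this class is constructed by get_model rather than instantiated directly. The sketch below shows a direct instantiation purely for illustration; it assumes a CUDA device, an environment where initialize_torch_distributed can succeed (e.g. the TGI server process), locally available safetensors weights, and the example checkpoint id "databricks/dbrx-instruct", none of which come from this commit.

# Hedged usage sketch; only runnable inside the server's GPU/distributed setup.
import torch

from text_generation_server.models.flash_dbrx import FlashDbrx

model = FlashDbrx(
    model_id="databricks/dbrx-instruct",  # assumed example checkpoint id
    revision=None,
    quantize=None,           # "gptq" / "awq" would also trigger _set_gptq_params
    use_medusa=None,
    dtype=torch.bfloat16,    # matches the default chosen in __init__ on GPU
    trust_remote_code=False,
)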