OpenDAS / text-generation-inference · Commits

Commit 82670d97 (unverified)
feat: add quant to mixtral (#1337)

Authored Dec 12, 2023 by OlivierDehaene; committed by GitHub on Dec 12, 2023.
Parent: ec6d4592

Showing 4 changed files with 184 additions and 35 deletions:

- server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py (+1 / -3)
- server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py (+161 / -14)
- server/text_generation_server/models/flash_mistral.py (+20 / -14)
- server/text_generation_server/utils/layers.py (+2 / -4)
server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
```diff
@@ -434,8 +434,6 @@ class FlashMistralForCausalLM(torch.nn.Module):
             weights=weights,
         )
         self.max_past = config.sliding_window
-        if self.max_past is None:
-            raise ValueError("max_past cannot be None")
 
     def forward(
         self,
```
```diff
@@ -454,7 +452,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
         if prefill_cache_indices is not None:
             # Slots also need to be sliced as it has the same size as the whole kv tensor
             slots = slots[prefill_cache_indices]
-        else:
+        elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
             max_s = min(self.max_past, max_s)
```
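The `else:` → `elif self.max_past is not None:` change above makes the decode-time clamp conditional, since `max_past` (the sliding window) may now legitimately be `None`; Mixtral sets no sliding window. A minimal, runnable sketch of the resulting behavior, with illustrative values:

```python
from typing import Optional

def clamp_max_s(max_s: int, max_past: Optional[int]) -> int:
    # Paged attention needs max_s clamped to the sliding window;
    # with no window (max_past is None) the true value is kept.
    if max_past is not None:
        max_s = min(max_past, max_s)
    return max_s

assert clamp_max_s(8192, 4096) == 4096  # window caps the decode length
assert clamp_max_s(8192, None) == 8192  # no sliding window: unchanged
```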
server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
```diff
@@ -365,9 +365,9 @@ class BlockSparseMoE(nn.Module):
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
 
         # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
-        self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).t()
+        self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights)
         self.w2 = _load_experts(config, f"{prefix}.experts", "w2", weights)
-        self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).t()
+        self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights)
 
         self.offsets = None
         self.offsets_block_rows = 0
```
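Dropping the `.t()` at load time keeps `w1` and `w3` in the merged `(n_experts * ffn_dim, hidden_dim)` layout. The sparse path transposes at the point of use (see the `stk.ops.sdd` hunk below), while the new dense path views the same tensor as a per-expert batch. A shape-only sketch with toy dimensions (not Mixtral's real sizes):

```python
import torch

n_experts, ffn_dim, hidden_dim = 4, 6, 3  # toy sizes for illustration
w1 = torch.randn(n_experts * ffn_dim, hidden_dim)

# Sparse path: one big transposed matrix for the block-sparse matmul.
assert w1.t().shape == (hidden_dim, n_experts * ffn_dim)

# Dense path: a batch of per-expert (hidden_dim, ffn_dim) matrices.
w1_batched = w1.view(n_experts, ffn_dim, hidden_dim).permute(0, 2, 1)
assert w1_batched.shape == (n_experts, hidden_dim, ffn_dim)
```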
```diff
@@ -467,8 +467,7 @@ class BlockSparseMoE(nn.Module):
         return indices, bin_ids, bins, padded_bins, tokens_per_expert
 
-    @torch.inference_mode()
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def sparse_forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         x: (sequence_length, model_dim)
         gate_logits: (sequence_length, n_experts)
```
```diff
@@ -502,8 +501,8 @@ class BlockSparseMoE(nn.Module):
         # (top_k * sequence_length + padding, ffn_dim * n_experts)
         x = stk.Matrix(
             topo.size(),
-            self.act(stk.ops.sdd(x, self.w1, topo).data)
-            * stk.ops.sdd(x, self.w3, topo).data,
+            self.act(stk.ops.sdd(x, self.w1.t(), topo).data)
+            * stk.ops.sdd(x, self.w3.t(), topo).data,
             topo.row_indices,
             topo.column_indices,
             topo.offsets,
```
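With the weights stored un-transposed, the `.t()` moves into the `stk.ops.sdd` calls, which compute the gated intermediate only on the blocks selected by `topo`. A dense, toy emulation of what each routed block computes, assuming `silu` as the activation:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 8)    # (tokens, hidden_dim), toy sizes
w1 = torch.randn(16, 8)  # (ffn_dim, hidden_dim), stored un-transposed
w3 = torch.randn(16, 8)

# act(x @ w1.t()) * (x @ w3.t()): the gated intermediate; the sparse
# kernel materializes this only for blocks the router selected.
inter = F.silu(x @ w1.t()) * (x @ w3.t())
assert inter.shape == (4, 16)
```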
```diff
@@ -534,6 +533,156 @@ class BlockSparseMoE(nn.Module):
         return x.view(*input_shape)
 
+    def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        x: (sequence_length, model_dim)
+        gate_logits: (sequence_length, n_experts)
+        """
+        # optional reshape
+        input_shape = x.shape
+        x = x.view(-1, input_shape[-1])
+
+        # gate_logits: (sequence_length, n_experts)
+        gate_logits = self.gate(x)
+        # all_probs: (sequence_length, n_experts) and upcast for softmax
+        all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
+
+        if self.top_k < self.num_experts:
+            _, not_selected_experts = torch.topk(
+                all_probs,
+                self.num_experts - self.top_k,
+                largest=False,
+                sorted=False,
+                dim=1,
+            )
+            # Mask not selected experts
+            all_probs.scatter_(1, not_selected_experts, 0)
+
+        # Re-normalize
+        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+
+        # Expand to [num_experts, sequence_length, model_dim]
+        x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
+
+        # Permute to [num_experts, model_dim, ffn_dim]
+        w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(0, 2, 1)
+        w3 = self.w3.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(0, 2, 1)
+
+        inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, w3)
+
+        out = torch.bmm(
+            inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim)
+        )
+        # Mask not selected experts
+        out *= weights.t().view(self.num_experts, -1, 1)
+
+        # Sum experts
+        out = out.sum(0)
+
+        # Reduce sum
+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(out, group=self.process_group)
+
+        return out
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if len(x) > 256:
+            return self.sparse_forward(x)
+        # This is faster when there is not a lot of tokens
+        return self.dense_forward(x)
+
+
+class DenseMoE(nn.Module):
+    def __init__(self, prefix, config: MixtralConfig, weights):
+        super().__init__()
+        self.hidden_dim = config.hidden_size
+        self.ffn_dim = config.intermediate_size // weights.process_group.size()
+        self.num_experts = config.num_local_experts
+        self.top_k = config.num_experts_per_tok
+
+        act = config.hidden_act
+        if "gelu" in act:
+            self.act = lambda x: torch.nn.functional.gelu(
+                x,
+                approximate="tanh"
+                if act in ["gelu_fast", "gelu_pytorch_tanh"]
+                else "none",
+            )
+        elif "silu" in act:
+            self.act = torch.nn.functional.silu
+        else:
+            self.act = ACT2FN[act]
+
+        # gating
+        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
+
+        self.w1 = [
+            TensorParallelColumnLinear.load(
+                config, prefix=f"{prefix}.experts.{i}.w1", weights=weights, bias=False
+            )
+            for i in range(self.num_experts)
+        ]
+        self.w3 = [
+            TensorParallelColumnLinear.load(
+                config, prefix=f"{prefix}.experts.{i}.w3", weights=weights, bias=False
+            )
+            for i in range(self.num_experts)
+        ]
+        self.w2 = [
+            TensorParallelRowLinear.load(
+                config, prefix=f"{prefix}.experts.{i}.w2", weights=weights, bias=False
+            )
+            for i in range(self.num_experts)
+        ]
+
+        self.process_group = weights.process_group
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        x: (sequence_length, model_dim)
+        gate_logits: (sequence_length, n_experts)
+        """
+        # optional reshape
+        input_shape = x.shape
+        x = x.view(-1, input_shape[-1])
+
+        # gate_logits: (sequence_length, n_experts)
+        gate_logits = self.gate(x)
+        # all_probs: (sequence_length, n_experts) and upcast for softmax
+        all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
+
+        if self.top_k < self.num_experts:
+            _, not_selected_experts = torch.topk(
+                all_probs,
+                self.num_experts - self.top_k,
+                largest=False,
+                sorted=False,
+                dim=1,
+            )
+            # Mask not selected experts
+            all_probs.scatter_(1, not_selected_experts, 0)
+
+        # Re-normalize
+        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+
+        # Final output tensor
+        out = x.new_zeros(x.shape[0], self.hidden_dim)
+        for i in range(self.num_experts):
+            h = self.act(self.w1[i](x)) * self.w3[i](x)
+            h = self.w2[i](h, reduce=False)
+            # Add expert output to out with masking
+            out += h * weights[:, i].view(-1, 1)
+
+        # Reduce sum
+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(out, group=self.process_group)
+
+        return out
+
+
 class MixtralLayer(nn.Module):
     def __init__(self, layer_id, config, weights):
```
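The routing in `dense_forward` (and again in `DenseMoE.forward`) is plain tensor ops: softmax over the gate logits, zero out all but the `top_k` experts by scattering zeros into the `num_experts - top_k` smallest entries, then renormalize. A standalone sketch with toy values:

```python
import torch

torch.manual_seed(0)
num_experts, top_k, seq_len = 8, 2, 5
gate_logits = torch.randn(seq_len, num_experts)

all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
_, not_selected = torch.topk(
    all_probs, num_experts - top_k, largest=False, sorted=False, dim=1
)
all_probs.scatter_(1, not_selected, 0)  # mask the non-top-k experts
weights = all_probs / all_probs.sum(dim=1, keepdim=True)

# Exactly top_k experts stay active per token, and the weights sum to 1.
assert (weights > 0).sum(dim=1).eq(top_k).all()
assert torch.allclose(weights.sum(dim=1), torch.ones(seq_len))
```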
```diff
@@ -543,9 +692,9 @@ class MixtralLayer(nn.Module):
         self.self_attn = MixtralAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
-        self.block_sparse_moe = BlockSparseMoE(
-            f"{prefix}.block_sparse_moe", config, weights
-        )
+        moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
+        self.moe = moe_cls(f"{prefix}.block_sparse_moe", config, weights)
 
         self.input_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
```
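This hunk is the point of the commit: `BlockSparseMoE` needs the raw merged expert weights, so whenever `config.quantize` is set the layer falls back to `DenseMoE`, whose per-expert `TensorParallelColumnLinear`/`TensorParallelRowLinear` layers load through the normal (quantization-aware) linear path. A minimal sketch of the dispatch with stand-in classes:

```python
class BlockSparseMoE: ...  # merged expert weights, sparse kernels
class DenseMoE: ...        # per-expert linear layers, quantization-friendly

def pick_moe_cls(quantize):
    # Any quantization scheme routes through the dense, per-expert path.
    return BlockSparseMoE if quantize is None else DenseMoE

assert pick_moe_cls(None) is BlockSparseMoE
assert pick_moe_cls("gptq") is DenseMoE  # "gptq" is an illustrative value
```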
```diff
@@ -591,9 +740,9 @@ class MixtralLayer(nn.Module):
             attn_output, res
         )
 
-        block_sparse_moe_output = self.block_sparse_moe(normed_attn_res_output)
+        moe_output = self.moe(normed_attn_res_output)
 
-        return block_sparse_moe_output, attn_res
+        return moe_output, attn_res
 
 
 class MixtralModel(torch.nn.Module):
```
```diff
@@ -675,8 +824,6 @@ class FlashMixtralForCausalLM(torch.nn.Module):
             weights=weights,
         )
         self.max_past = config.sliding_window
-        if self.max_past is None:
-            raise ValueError("max_past cannot be None")
 
     def forward(
         self,
```
```diff
@@ -695,7 +842,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         if prefill_cache_indices is not None:
             # Slots also need to be sliced as it has the same size as the whole kv tensor
             slots = slots[prefill_cache_indices]
-        else:
+        elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
             max_s = min(self.max_past, max_s)
```
server/text_generation_server/models/flash_mistral.py
```diff
@@ -136,9 +136,9 @@ class FlashMistralBatch(FlashCausalLMBatch):
             total_tokens = input_length + max_new_tokens - 1 + speculative_length
 
             # Needed blocks can not go over SLIDING_WINDOW_BLOCKS
-            needed_blocks = min(
-                math.ceil(total_tokens / BLOCK_SIZE), SLIDING_WINDOW_BLOCKS
-            )
+            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
+            if SLIDING_WINDOW_BLOCKS is not None:
+                needed_blocks = min(needed_blocks, SLIDING_WINDOW_BLOCKS)
             blocks += needed_blocks
 
             needed_blocks_slots.append((needed_blocks, total_tokens))
```
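`SLIDING_WINDOW_BLOCKS` may now be `None` (Mixtral has no sliding window), so the block cap applies only when a window exists. A sketch of the computation above, with an illustrative `BLOCK_SIZE`:

```python
import math

BLOCK_SIZE = 16  # illustrative paged-KV block size

def blocks_needed(total_tokens: int, sliding_window_blocks=None) -> int:
    needed = math.ceil(total_tokens / BLOCK_SIZE)
    if sliding_window_blocks is not None:
        needed = min(needed, sliding_window_blocks)
    return needed

assert blocks_needed(100) == 7                           # ceil(100 / 16)
assert blocks_needed(100, sliding_window_blocks=4) == 4  # capped by window
```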
```diff
@@ -152,6 +152,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
                 slot_indices.append(request_slot_indices)
 
             # Create tensor to slice into the kv tensor in prefill
+            if SLIDING_WINDOW is not None:
                 request_prefill_cache_indices = torch.arange(
                     cumulative_length + max(0, input_length - SLIDING_WINDOW),
                     cumulative_length + input_length,
```
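The guarded `torch.arange` builds, per request, the indices of the prefill positions that actually land in the KV cache: with a sliding window, only the last `SLIDING_WINDOW` tokens are kept; without one, the tensor is no longer built at all. A toy illustration of the slice:

```python
import torch

SLIDING_WINDOW = 4                     # toy window
cumulative_length, input_length = 10, 7

idx = torch.arange(
    cumulative_length + max(0, input_length - SLIDING_WINDOW),
    cumulative_length + input_length,
)
# Only the last 4 of this request's 7 prefill positions are cached.
assert idx.tolist() == [13, 14, 15, 16]
```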
```diff
@@ -209,11 +210,13 @@ class FlashMistralBatch(FlashCausalLMBatch):
             input_ids = np.concatenate(all_input_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
             slot_indices = torch.cat(slot_indices)
+            if SLIDING_WINDOW is not None:
                 prefill_cache_indices = torch.cat(prefill_cache_indices)
         else:
             input_ids = all_input_ids[0]
             position_ids = position_ids[0]
             slot_indices = slot_indices[0]
+            if SLIDING_WINDOW is not None:
                 prefill_cache_indices = prefill_cache_indices[0]
 
         cu_seqlen_prefill = torch.tensor(
```
```diff
@@ -222,7 +225,9 @@ class FlashMistralBatch(FlashCausalLMBatch):
         position_ids = position_ids.to(device)
         slot_indices = slot_indices.to(device)
-        prefill_cache_indices = prefill_cache_indices.to(device)
+        prefill_cache_indices = (
+            prefill_cache_indices.to(device) if SLIDING_WINDOW is not None else None
+        )
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         input_lengths_tensor = torch.tensor(
             input_lengths, dtype=torch.int32, device=device
```
```diff
@@ -314,6 +319,7 @@ class BaseFlashMistral(FlashCausalLM):
         config.quantize = quantize
 
         # Set context windows
+        if config.sliding_window is not None:
             SLIDING_WINDOW = config.sliding_window
             SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE)
```
server/text_generation_server/utils/layers.py
```diff
@@ -64,8 +64,6 @@ elif CAN_EXLLAMA:
     except ImportError:
         pass
 
-from typing import Optional
-
 HAS_EETQ = False
 try:
     from EETQ import quant_weights, w8_a16_gemm
```
```diff
@@ -489,9 +487,9 @@ class TensorParallelRowLinear(SuperLayer):
             process_group=weights.process_group,
         )
 
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
+    def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
         out = super().forward(input)
-        if self.process_group.size() > 1:
+        if self.process_group.size() > 1 and reduce:
             torch.distributed.all_reduce(out, group=self.process_group)
         return out
```
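`DenseMoE` calls `self.w2[i](h, reduce=False)` so that, under tensor parallelism, per-expert partial outputs are accumulated locally and a single `all_reduce` runs once per MoE layer rather than once per expert. A sketch of that accumulation, with hypothetical partial outputs standing in for the row-parallel results:

```python
import torch

def combine_experts(partial_outputs, weights, process_group=None):
    # partial_outputs[i]: this rank's partial result for expert i,
    # produced with reduce=False (no all_reduce inside the linear).
    out = torch.zeros_like(partial_outputs[0])
    for i, h in enumerate(partial_outputs):
        out += h * weights[:, i].view(-1, 1)
    if process_group is not None and process_group.size() > 1:
        # One collective for the whole layer.
        torch.distributed.all_reduce(out, group=process_group)
    return out

# Single-process usage with toy tensors (no all_reduce is triggered):
outs = [torch.ones(2, 3) * (i + 1) for i in range(2)]
w = torch.tensor([[0.25, 0.75], [0.5, 0.5]])
combined = combine_experts(outs, w)
assert torch.allclose(combined[0], torch.full((3,), 1.75))
```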