OpenDAS / text-generation-inference / Commits / 211b211e

Commit 211b211e (unverified), authored Jul 18, 2023 by Nicolas Patry, committed by GitHub on Jul 18, 2023

feat(server): add support for llamav2 (#633)

parent 3b71c385
Showing 2 changed files with 58 additions and 20 deletions (+58 −20):

  server/text_generation_server/models/custom_modeling/flash_llama_modeling.py  +57 −19
  server/text_generation_server/models/flash_llama.py  +1 −1
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py (view file @ 211b211e)
@@ -39,6 +39,7 @@ from text_generation_server.utils.layers import (
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
     TensorParallelHead,
+    get_linear,
 )
@@ -59,7 +60,8 @@ class LlamaRMSNorm(nn.Module):
             hidden_states += residual
             residual = hidden_states

-            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states.to(torch.float32)
+            variance = hidden_states.pow(2).mean(-1, keepdim=True)
             hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
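This hunk casts the hidden states to float32 up front, so both the variance reduction and the rescaling run in full precision rather than only the pow(2).mean(...) term. A minimal standalone sketch of that pattern (not the library code; the cast back to the activation dtype and the 1e-6 epsilon are assumptions here):

import torch

def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # Do the reduction and rescale in float32, then return to the activation dtype.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)

x = torch.randn(4, 4096, dtype=torch.float16)
w = torch.ones(4096, dtype=torch.float16)
print(rms_norm(x, w).dtype)  # torch.float16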
@@ -94,6 +96,27 @@ class LlamaRMSNorm(nn.Module):
             return normed_hidden_states, res


+def _load_gqa(config, prefix: str, weights):
+    w = [
+        weights.get_sharded(f"{prefix}.q_proj.weight", dim=0),
+        weights.get_sharded(f"{prefix}.k_proj.weight", dim=0),
+        weights.get_sharded(f"{prefix}.v_proj.weight", dim=0),
+    ]
+    weight = torch.cat(w, dim=0)
+    weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+    bias = None
+    assert config.hidden_size % config.num_attention_heads == 0
+    head_size = config.hidden_size // config.num_attention_heads
+    assert config.num_attention_heads % weights.process_group.size() == 0
+    num_heads = config.num_attention_heads // weights.process_group.size()
+    num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
+    assert list(weight.shape) == [
+        (num_heads + 2 * num_key_value_heads) * head_size,
+        config.hidden_size,
+    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+
+
 class FlashLlamaAttention(torch.nn.Module):
     def __init__(
         self,
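The new _load_gqa helper concatenates the sharded q/k/v projection weights into one fused matrix and asserts its per-shard shape: queries contribute num_heads * head_size rows, while keys and values contribute only num_key_value_heads * head_size rows each. A worked shape check with illustrative numbers (a Llama-2-70B-like config and two hypothetical tensor-parallel shards, not values read from the diff):

hidden_size = 8192
num_attention_heads = 64
num_key_value_heads = 8
world_size = 2  # hypothetical tensor-parallel shard count

head_size = hidden_size // num_attention_heads     # 128
num_heads = num_attention_heads // world_size      # 32 query heads per shard
num_kv_heads = num_key_value_heads // world_size   # 4 KV heads per shard

rows = (num_heads + 2 * num_kv_heads) * head_size  # 4096 + 1024 = 5120
print([rows, hidden_size])                         # [5120, 8192], the expected fused QKV weight shape per shard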
@@ -118,22 +141,29 @@ class FlashLlamaAttention(torch.nn.Module):
                 f"and `num_shards`: {weights.process_group.size()}"
             )
         self.num_heads = self.num_heads // weights.process_group.size()
-        self.query_key_value = TensorParallelColumnLinear.load_multi(
-            config,
-            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-            dim=0,
-            weights=weights,
-            bias=False,
-        )
+        self.num_key_value_heads = (
+            config.num_key_value_heads // weights.process_group.size()
+        )
+        if config.num_attention_heads != config.num_key_value_heads:
+            self.query_key_value = _load_gqa(config, prefix, weights)
+        else:
+            self.query_key_value = TensorParallelColumnLinear.load_multi(
+                config,
+                prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+                dim=0,
+                weights=weights,
+                bias=False,
+            )
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
             bias=False,
         )
+        self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
-            0, self.num_heads, dtype=torch.int32, device=weights.device
-        )
+            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+        ).repeat_interleave(self.num_groups)

     def forward(
         self,
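With grouped-query attention, every query head reads the cache of the KV head it shares, so kv_head_mapping is now the KV-head index repeated once per group instead of the identity mapping over all heads. A toy example of the repeat_interleave call (head counts are made up, not taken from the config):

import torch

num_heads = 8
num_key_value_heads = 2
num_groups = num_heads // num_key_value_heads  # 4 query heads per KV head

kv_head_mapping = torch.arange(
    0, num_key_value_heads, dtype=torch.int32
).repeat_interleave(num_groups)
print(kv_head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)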
@@ -148,26 +178,33 @@ class FlashLlamaAttention(torch.nn.Module):
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
-        qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
+        query, kv = qkv.split(
+            [
+                self.head_size * self.num_heads,
+                2 * self.head_size * self.num_key_value_heads,
+            ],
+            dim=1,
+        )
+        query = query.view(-1, self.num_heads, self.head_size)
+        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

         # Inplace rotary
-        self.rotary_emb(qkv[:, 0], cos, sin)
-        self.rotary_emb(qkv[:, 1], cos, sin)
+        self.rotary_emb(query, cos, sin)
+        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)

         vllm_cache_ops.reshape_and_cache(
-            qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
+            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
         )

         # output tensor
-        attn_output = torch.empty_like(qkv[:, 0])
+        attn_output = torch.empty_like(query)

         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attention(
-                qkv[:, 0],
-                qkv[:, 1],
-                qkv[:, 2],
+                query,
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlen_prefill,
                 max_s,
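Since queries and keys/values no longer share a head count, the fused projection output can no longer be viewed as (-1, 3, num_heads, head_size); it is first split by width into a query part and a k/v part. A small sketch of the shape arithmetic with toy sizes (head counts, head_size and token count are arbitrary here):

import torch

num_heads, num_key_value_heads, head_size, tokens = 8, 2, 64, 5

qkv = torch.randn(tokens, (num_heads + 2 * num_key_value_heads) * head_size)
query, kv = qkv.split(
    [head_size * num_heads, 2 * head_size * num_key_value_heads], dim=1
)
query = query.view(-1, num_heads, head_size)
kv = kv.view(-1, 2, num_key_value_heads, head_size)
print(query.shape, kv.shape)  # torch.Size([5, 8, 64]) torch.Size([5, 2, 2, 64])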
@@ -179,7 +216,7 @@ class FlashLlamaAttention(torch.nn.Module):
             block_size = kv_cache[1].shape[3]
             vllm_attention_ops.single_query_cached_kv_attention(
                 attn_output,
-                qkv[:, 0],
+                query,
                 kv_cache[0],
                 kv_cache[1],
                 self.kv_head_mapping,
@@ -316,6 +353,7 @@ class FlashLlamaModel(torch.nn.Module):
         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads
+        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

     def forward(
         self,
server/text_generation_server/models/flash_llama.py (view file @ 211b211e)
@@ -69,7 +69,7 @@ class FlashLlama(FlashCausalLM):
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_heads,
+            num_kv_heads=model.model.num_key_value_heads,
             head_size=model.model.head_size,
             dtype=dtype,
             device=device,
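Passing num_key_value_heads rather than num_heads here sizes the paged KV cache for the KV heads only, which is where grouped-query attention saves memory. A back-of-the-envelope per-token estimate (Llama-2-70B-like numbers in fp16, illustrative rather than measured):

# per-token KV-cache bytes ≈ 2 (K and V) * layers * kv_heads * head_size * bytes_per_value
layers, head_size, bytes_per_value = 80, 128, 2

per_token_mha = 2 * layers * 64 * head_size * bytes_per_value  # cache sized by num_heads (old)
per_token_gqa = 2 * layers * 8 * head_size * bytes_per_value   # cache sized by num_key_value_heads (new)
print(per_token_mha, per_token_gqa)  # 2621440 327680, an 8x reduction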