OpenDAS / ColossalAI / Commits / 77a93283

Unverified commit 77a93283, authored Oct 13, 2023 by Xu Kai, committed by GitHub on Oct 13, 2023

[inference] add llama2 support (#4898)

* add llama2 support
* fix multi group bug

Parent: 39f2582e
Showing 5 changed files with 153 additions and 44 deletions (+153 / -44):
colossalai/inference/tensor_parallel/engine.py            +25  -16
colossalai/inference/tensor_parallel/modeling/_utils.py     +1   -1
colossalai/inference/tensor_parallel/modeling/llama.py     +56  -26
colossalai/kernel/triton/__init__.py                        +2   -1
tests/test_infer/test_llama2_infer.py                      +69   -0
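At a glance, the commit teaches the tensor-parallel inference path about LLaMA-2's grouped-query attention (GQA): the engine now also reads num_key_value_heads from the model config, the LLaMA forwards project K/V onto that smaller head count, and the new llama2_* Triton kernels plus the Llama2TokenAttentionForwards path handle the case where several query heads share one KV head. A minimal sketch of the head arithmetic involved (illustration only, not code from this diff):

# GQA head bookkeeping; LLaMA-2 70B, for instance, pairs 64 query heads with 8 KV heads.
num_attention_heads = 64
num_key_value_heads = 8

# How many query heads share each KV head; 1 degenerates to plain multi-head attention.
num_key_value_groups = num_attention_heads // num_key_value_heads
assert num_key_value_groups == 8

# The modified forwards dispatch on this value:
#   num_key_value_groups == 1 -> original llama kernels (MHA)
#   num_key_value_groups  > 1 -> the new llama2_* kernels (GQA)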
colossalai/inference/tensor_parallel/engine.py
from typing import Any, Callable, List, Optional, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from transformers import BloomForCausalLM, LlamaForCausalLM
from transformers.generation import GenerationConfig
...
...
@@ -74,9 +73,14 @@ class TPInferEngine:
            model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers
        )
        self.layer_num = num_hidden_layers
-        self.multi_query_group_num = (
-            model.config.multi_query_group_num if hasattr(model.config, "multi_query_group_num") else 0
-        )
+        self.multi_query_group_num = 0
+        if hasattr(model.config, "multi_query_group_num"):
+            self.multi_query_group_num = model.config.multi_query_group_num
+        if hasattr(model.config, "num_key_value_heads"):
+            self.multi_query_group_num = model.config.num_key_value_heads
        self.tp_size = -1  # to be set with given shard config in self.prepare_shard_config
        self.cache_manager = None
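Side note (not part of the diff): the hasattr probes above cover both attribute spellings found in Hugging Face configs; a quick hypothetical check of what they resolve to for a LLaMA-2 style config, assuming a GQA-aware transformers release:

from transformers import LlamaConfig

cfg = LlamaConfig(num_attention_heads=32, num_key_value_heads=8)
print(hasattr(cfg, "multi_query_group_num"))  # False -> ChatGLM2-style branch skipped
print(hasattr(cfg, "num_key_value_heads"))    # True  -> multi_query_group_num becomes 8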
...
...
@@ -97,6 +101,7 @@ class TPInferEngine:
        assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig"
        assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
        self.head_num //= self.tp_size  # update sharded number of heads
        if self.multi_query_group_num:
            # NOTE the logic of MQA tensor parallelism should be specified.
            assert (
...
...
@@ -116,13 +121,15 @@ class TPInferEngine:
    def _post_init_gptq_buffer(self, model: nn.Module) -> None:
        from colossalai.inference.quant.gptq.cai_gptq import CaiQuantLinear

        HAS_GPTQ_CUDA = False
        try:
            from colossalai.kernel.op_builder.gptq import GPTQBuilder

            gptq_cuda = GPTQBuilder().load()
            HAS_GPTQ_CUDA = True
        except ImportError:
-            warnings.warn('CUDA gptq is not installed')
+            warnings.warn("CUDA gptq is not installed")
            HAS_GPTQ_CUDA = False

        for name, submodule in model.named_modules():
...
...
@@ -130,8 +137,9 @@ class TPInferEngine:
                self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8)

                if self.use_act_order:
-                    self.max_inner_outer_dim = max(self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures)
+                    self.max_inner_outer_dim = max(
+                        self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures
+                    )
                self.bits = submodule.bits
        if not (HAS_GPTQ_CUDA and self.bits == 4):
            return
...
...
@@ -141,15 +149,16 @@ class TPInferEngine:
        max_input_len = self.max_input_len
        # The temp_state buffer is required to reorder X in the act-order case.
        # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
-        self.gptq_temp_state_buffer = torch.zeros((max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device())
-        self.gptq_temp_dq_buffer = torch.zeros((1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device())
-        gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer)
+        self.gptq_temp_state_buffer = torch.zeros(
+            (max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
+        )
+        self.gptq_temp_dq_buffer = torch.zeros(
+            (1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()
+        )
+        gptq_cuda.prepare_buffers(
+            torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer
+        )

        # Using the default from exllama repo here.
        matmul_recons_thd = 8
        matmul_fused_remap = False
...
...
colossalai/inference/tensor_parallel/modeling/_utils.py
...
...
@@ -45,7 +45,7 @@ def init_to_get_rotary(self, base=10000, use_elem=False):
    base = float(base)

    # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
-    ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", None))
+    ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None)
    if ntk_alpha is not None:
        ntk_alpha = float(ntk_alpha)
...
...
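For context (not part of the diff): init_to_get_rotary reads the optional INFER_NTK_ALPHA environment variable, and the change above stops wrapping os.environ.get(...) in float() so that an unset variable no longer fails on float(None). A hedged usage sketch:

# Sketch only: opt into NTK-aware RoPE scaling by exporting the variable before the
# rotary tables are built; leaving it unset simply skips the alpha-based adjustment.
import os

os.environ["INFER_NTK_ALPHA"] = "2.0"  # hypothetical value; any positive float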
colossalai/inference/tensor_parallel/modeling/llama.py
...
...
@@ -5,7 +5,13 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm

from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton import llama_context_attn_fwd, rotary_embedding_fwd, token_attention_fwd
+from colossalai.kernel.triton import (
+    llama2_context_attn_fwd,
+    llama_context_attn_fwd,
+    rotary_embedding_fwd,
+    token_attention_fwd,
+)
+from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards

from ._utils import copy_kv_to_mem_cache
...
...
@@ -138,6 +144,7 @@ class LlamaInferenceForwards:
            seq_len = infer_state.seq_len
            infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
            infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+            infer_state.other_kv_index = infer_state.block_loc[0, seq_length_with_past - 1].item()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
...
...
@@ -261,8 +268,8 @@ class LlamaInferenceForwards:
        # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head]
        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
-        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
-        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)

        # NOTE might want to revise
        #   need some way to record the length of past key values cache
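With GQA, key_states and value_states now carry only num_key_value_heads heads while query_states keeps num_heads. The fused Triton llama2_* kernels consume these shapes directly; below is a plain-PyTorch sketch of the equivalent computation, for intuition only (made-up dimensions, causal masking omitted, not the kernel itself):

import torch

bsz, q_len, head_dim = 2, 4, 64
num_heads, num_key_value_heads = 8, 2
groups = num_heads // num_key_value_heads  # 4 query heads share each KV head

q = torch.randn(bsz, q_len, num_heads, head_dim)
k = torch.randn(bsz, q_len, num_key_value_heads, head_dim)
v = torch.randn(bsz, q_len, num_key_value_heads, head_dim)

# Expand KV heads so every query head sees a matching KV head, then run standard attention.
k = k.repeat_interleave(groups, dim=2)
v = v.repeat_interleave(groups, dim=2)

scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / head_dim**0.5
attn = torch.softmax(scores, dim=-1)
out = torch.einsum("bhqk,bkhd->bqhd", attn, v)
print(out.shape)  # torch.Size([2, 4, 8, 64])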
...
...
@@ -274,11 +281,11 @@ class LlamaInferenceForwards:
        # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, )
        rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
-        rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)
+        rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin)

        query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
-        key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
-        value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
+        key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
+        value_states = value_states.reshape(-1, self.num_key_value_heads, self.head_dim)

        if infer_state.is_context_stage:
            # first token generation
...
...
@@ -294,6 +301,7 @@ class LlamaInferenceForwards:
            attn_output = torch.empty_like(query_states)

+            if self.num_key_value_groups == 1:
                llama_context_attn_fwd(
                    query_states,
                    key_states,
...
...
@@ -303,6 +311,16 @@ class LlamaInferenceForwards:
                    infer_state.seq_len,
                    infer_state.cache_manager.past_key_values_length,
                )
+            else:
+                llama2_context_attn_fwd(
+                    query_states,
+                    key_states,
+                    value_states,
+                    attn_output,
+                    infer_state.start_loc,
+                    infer_state.seq_len,
+                    infer_state.cache_manager.past_key_values_length,
+                )
        else:
            if infer_state.decode_is_contiguous:
                # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
...
...
@@ -330,6 +348,7 @@ class LlamaInferenceForwards:
            # (batch_size, seqlen, nheads, headdim)
            attn_output = torch.empty_like(query_states)

+            if self.num_key_value_groups == 1:
                token_attention_fwd(
                    query_states,
                    infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
...
...
@@ -340,7 +359,18 @@ class LlamaInferenceForwards:
                    infer_state.seq_len,
                    infer_state.cache_manager.past_key_values_length,
                )
+            else:
+                Llama2TokenAttentionForwards.token_attn(
+                    query_states,
+                    infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+                    infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
+                    attn_output,
+                    infer_state.block_loc,
+                    infer_state.start_loc,
+                    infer_state.seq_len,
+                    infer_state.cache_manager.past_key_values_length,
+                    infer_state.other_kv_index,
+                )

        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
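Note how the decode-stage GQA branch passes infer_state.other_kv_index, the value recorded earlier in the model forward (the block_loc[0, seq_length_with_past - 1] line added above); the MHA token_attention_fwd path does not take it.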
...
...
colossalai/kernel/triton/__init__.py
...
...
@@ -9,7 +9,7 @@ except ImportError:
    # There may exist import error even if we have triton installed.

if HAS_TRITON:
-    from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
+    from .context_attention import bloom_context_attn_fwd, llama2_context_attn_fwd, llama_context_attn_fwd
    from .copy_kv_cache_dest import copy_kv_cache_to_dest
    from .fused_layernorm import layer_norm
    from .gptq_triton import gptq_fused_linear_triton
...
...
@@ -20,6 +20,7 @@ if HAS_TRITON:
    __all__ = [
        "llama_context_attn_fwd",
+        "llama2_context_attn_fwd",
        "bloom_context_attn_fwd",
        "softmax",
        "layer_norm",
...
...
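Downstream code that wants the new kernel can mirror the guarded-import pattern of this __init__.py. A minimal sketch, assuming Triton is installed and that HAS_TRITON is importable from the package, as the if HAS_TRITON: block above suggests:

from colossalai.kernel.triton import HAS_TRITON

if HAS_TRITON:
    from colossalai.kernel.triton import llama2_context_attn_fwd
else:
    llama2_context_attn_fwd = None  # caller falls back to an eager implementation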
tests/test_infer/test_llama2_infer.py
0 → 100644 (new file)
import os

import pytest
import torch
from packaging import version
from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig

import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
TPSIZE = 2
BATCH_SIZE = 8
MAX_INPUT_LEN = 12
MAX_OUTPUT_LEN = 100

CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")


@parameterize(
    "test_config",
    [
        {
            "tp_size": TPSIZE,
        }
    ],
)
def run_llama_test(test_config):
    llama_config = LlamaConfig(
        num_hidden_layers=2, num_key_value_heads=8, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024
    )
    model = LlamaForCausalLM(llama_config)
    model = model.half()

    shard_config = ShardConfig(
        enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
    )
    infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
    generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)

    input_tokens = {
        "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
        "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
    }
    outputs = infer_engine.generate(input_tokens, **generate_kwargs)
    assert outputs is not None


def check_llama(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_llama_test()


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama():
    spawn(check_llama, TPSIZE)


if __name__ == "__main__":
    test_llama()
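The new test can be run directly (python tests/test_infer/test_llama2_infer.py) or through pytest. Either way, spawn(check_llama, TPSIZE) launches two NCCL-backed ranks, so it is meant for a machine with at least two GPUs and, per the skipif guard, a CUDA version above 11.5. Setting num_key_value_heads=8 on the tiny LlamaConfig (whose default num_attention_heads is 32) is what steers the engine onto the new grouped-query code path.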