OpenDAS / ColossalAI · Commits

Commit 013a4bed (unverified), authored Oct 04, 2023 by Jianghai, committed via GitHub on Oct 04, 2023.

[inference] fix import bug and delete useless init (#4830)

* fix import bug and remove useless init
* fix
* fix
* fix
Parent: 573f2705

Showing 9 changed files with 121 additions and 154 deletions (+121 −154).
Files changed:
  colossalai/inference/tensor_parallel/modeling/__init__.py   +0  −2
  colossalai/inference/tensor_parallel/modeling/_utils.py      +58 −1
  colossalai/inference/tensor_parallel/modeling/llama.py       +5  −19
  colossalai/inference/tensor_parallel/policies/chatglm2.py    +18 −21
  colossalai/inference/tensor_parallel/policies/llama.py       +15 −10
  colossalai/kernel/triton/__init__.py                         +6  −4
  examples/inference/bench_llama.py                            +0  −25
  examples/inference/gptq_llama.py                             +19 −47
  tests/test_infer/test_llama_infer.py                         +0  −25
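The common thread across these files is that per-file copies of init_to_get_rotary and copy_kv_to_mem_cache are deleted in favour of the shared helpers in colossalai/inference/tensor_parallel/modeling/_utils.py, and imports are adjusted accordingly. A minimal usage sketch of the shared rotary init after this commit (the checkpoint path is a placeholder; the calls mirror the examples in this diff):

# Minimal sketch; the path is a placeholder, the calls mirror the examples below.
from transformers import LlamaForCausalLM

from colossalai.inference.tensor_parallel.modeling._utils import init_to_get_rotary

model = LlamaForCausalLM.from_pretrained("/path/to/llama")  # placeholder checkpoint
init_to_get_rotary(model.model, base=10000)  # caches rotary cos/sin on the inner LlamaModel
model = model.half()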
colossalai/inference/tensor_parallel/modeling/__init__.py

-import _utils
 from .bloom import BloomInferenceForwards
 from .chatglm2 import ChatGLM2InferenceForwards
 from .llama import LlamaInferenceForwards
...
colossalai/inference/tensor_parallel/modeling/_utils.py

"""
Utils for model inference
"""
import os

import torch

from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest


def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
    """
    This function copies the key and value cache to the memory cache.

    Args:
        layer_id : id of the current layer
        key_buffer : key cache
        value_buffer : value cache
        context_mem_index : index of the memory cache in the kv cache manager
        mem_manager : cache manager
    """
    copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
    copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
    return


def init_to_get_rotary(self, base=10000, use_elem=False):
    """
    This function initializes the rotary positional embedding; it is compatible with all models and is called in ShardFormer.

    Args:
        self : model that holds the rotary positional embedding
        base : base value used to compute the rotary frequencies
        use_elem : activated when using chatglm-based models
    """
    self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
    if not hasattr(self.config, "rope_scaling"):
        rope_scaling_factor = 1.0
    else:
        rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0

    if hasattr(self.config, "max_sequence_length"):
        max_seq_len = self.config.max_sequence_length
    elif hasattr(self.config, "max_position_embeddings"):
        max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
    else:
        max_seq_len = 2048 * rope_scaling_factor
    base = float(base)

    # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
    ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None)

    if ntk_alpha is not None:
        ntk_alpha = float(ntk_alpha)
        assert ntk_alpha >= 1, "NTK alpha must be greater than or equal to 1"
        if ntk_alpha > 1:
            print(f"Note: NTK enabled, alpha set to {ntk_alpha}")
        max_seq_len *= ntk_alpha
        base = base * (ntk_alpha ** (self.config.head_dim_ / (self.config.head_dim_ - 2)))  # Base change formula

    n_elem = self.config.head_dim_
    if use_elem:
        n_elem //= 2

    inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem))
    t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
    freqs = torch.outer(t, inv_freq)

    self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
    self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
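When the INFER_NTK_ALPHA environment variable is set, the function stretches max_seq_len by alpha and rescales the RoPE base by alpha^(d/(d-2)), where d is the head dimension. A quick standalone check of that base-change arithmetic (the numbers are illustrative, not from the diff):

# Illustrative arithmetic for the NTK-aware base change above.
head_dim = 128       # a typical LLaMA head dimension (illustrative)
base = 10000.0
ntk_alpha = 2.0      # value that would come from INFER_NTK_ALPHA

scaled_base = base * ntk_alpha ** (head_dim / (head_dim - 2))
print(scaled_base)   # roughly 20221, about 2.02x the original base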
colossalai/inference/tensor_parallel/modeling/llama.py

@@ -5,12 +5,9 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm

 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton import (
-    copy_kv_cache_to_dest,
-    llama_context_attn_fwd,
-    rotary_embedding_fwd,
-    token_attention_fwd,
-)
+from colossalai.kernel.triton import llama_context_attn_fwd, rotary_embedding_fwd, token_attention_fwd
+
+from ._utils import copy_kv_to_mem_cache

 try:
     from vllm import layernorm_ops, pos_encoding_ops
...

@@ -46,12 +43,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     return q_embed, k_embed


-def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
-    copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
-    copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
-    return
-
-
 class LlamaInferenceForwards:
     """
     This class holds forwards for llama inference.
...

@@ -285,11 +276,6 @@ class LlamaInferenceForwards:
         rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
         rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)

-        def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
-            copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
-            copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
-            return
-
         query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
         key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
         value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
...

@@ -298,7 +284,7 @@ class LlamaInferenceForwards:
             # first token generation
             # copy key and value calculated in current step to memory manager
-            _copy_kv_to_mem_cache(
+            copy_kv_to_mem_cache(
                 infer_state.decode_layer_id,
                 key_states,
                 value_states,
...

@@ -331,7 +317,7 @@ class LlamaInferenceForwards:
             else:
                 # if decode is not contiguous, use triton kernel to copy key and value cache
                 # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head]
-                _copy_kv_to_mem_cache(
+                copy_kv_to_mem_cache(
                     infer_state.decode_layer_id,
                     key_states,
                     value_states,
...
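The two call sites above are truncated in the diff; judging from the helper's signature in _utils.py, the remaining arguments are the destination slot indices and the engine's cache manager. A hedged sketch of the decode-path call (the attribute names on infer_state beyond decode_layer_id are assumptions):

# Sketch of the decode-path copy; decode_mem_index and cache_manager are assumed names.
copy_kv_to_mem_cache(
    infer_state.decode_layer_id,   # layer whose K/V is being written
    key_states,                    # [num_tokens, num_heads, head_dim]
    value_states,
    infer_state.decode_mem_index,  # assumed: destination slots in the KV cache
    infer_state.cache_manager,     # assumed: manager owning key_buffer/value_buffer per layer
)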
colossalai/inference/tensor_parallel/policies/chatglm2.py

 from functools import partial

 import torch

 from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
     ChatGLMForConditionalGeneration,
     ChatGLMModel,
...
@@ -9,13 +7,14 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
     GLMTransformer,
     SelfAttention,
 )

 # import colossalai
 from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy

-from ..modeling.chatglm2 import ChatGLM2InferenceForwards, _init_to_get_rotary
+from ..modeling._utils import init_to_get_rotary
+from ..modeling.chatglm2 import ChatGLM2InferenceForwards

 try:
     from colossalai.kernel.triton.rms_norm import rmsnorm_forward

     HAS_TRITON_RMSNORM = True
 except:
     print("you should install triton from https://github.com/openai/triton")
...

@@ -23,7 +22,6 @@ except:
 class ChatGLM2InferPolicy(ChatGLMModelPolicy):
     def __init__(self) -> None:
         super().__init__()
...

@@ -32,45 +30,44 @@ class ChatGLM2InferPolicy(ChatGLMModelPolicy):
         self.shard_config._infer()

         model_infer_forward = ChatGLM2InferenceForwards.chatglm_model_forward
-        method_replacement = {'forward': model_infer_forward}
+        method_replacement = {"forward": model_infer_forward}
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=ChatGLMModel)

         encoder_infer_forward = ChatGLM2InferenceForwards.chatglm_encoder_forward
-        method_replacement = {'forward': encoder_infer_forward}
+        method_replacement = {"forward": encoder_infer_forward}
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=GLMTransformer)

         encoder_layer_infer_forward = ChatGLM2InferenceForwards.chatglm_glmblock_forward
-        method_replacement = {'forward': encoder_layer_infer_forward}
+        method_replacement = {"forward": encoder_layer_infer_forward}
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=GLMBlock)

         attn_infer_forward = ChatGLM2InferenceForwards.chatglm_flash_attn_kvcache_forward
-        method_replacement = {'forward': attn_infer_forward}
+        method_replacement = {"forward": attn_infer_forward}
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=SelfAttention)

         # for rmsnorm and others, we need to check the shape
         return policy

     def postprocess(self):
-        _init_to_get_rotary(self.model)
+        init_to_get_rotary(self.model)
         return self.model


 class ChatGLM2ForConditionalGenerationInferPolicy(ChatGLM2InferPolicy):
     def __init__(self) -> None:
         super().__init__()

     def module_policy(self):
         policy = super().module_policy()
         model_infer_forward = ChatGLM2InferenceForwards.chatglm_for_conditional_generation_forward
-        method_replacement = {'forward': partial(model_infer_forward)}
+        method_replacement = {"forward": partial(model_infer_forward)}
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=ChatGLMForConditionalGeneration)
         return policy

     def postprocess(self):
...
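module_policy repeats the same replace-the-forward idiom for each target module. Purely as an illustration (not part of the diff), the four registrations above collapse into one loop inside module_policy, where self and policy come from the surrounding method:

# Illustrative only: the four method replacements above, written as a loop.
infer_forwards = {
    ChatGLMModel: ChatGLM2InferenceForwards.chatglm_model_forward,
    GLMTransformer: ChatGLM2InferenceForwards.chatglm_encoder_forward,
    GLMBlock: ChatGLM2InferenceForwards.chatglm_glmblock_forward,
    SelfAttention: ChatGLM2InferenceForwards.chatglm_flash_attn_kvcache_forward,
}
for target_key, infer_forward in infer_forwards.items():
    self.append_or_create_method_replacement(
        description={"forward": infer_forward}, policy=policy, target_key=target_key
    )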
colossalai/inference/tensor_parallel/policies/llama.py

@@ -3,11 +3,12 @@ from functools import partial
 import torch
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm

 from colossalai.shardformer.layer import VocabParallelEmbedding1D
-from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription

 # import colossalai
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

+from ..modeling._utils import init_to_get_rotary
 from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward

 try:
...

@@ -50,38 +51,38 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
                 SubModuleReplacementDescription(
                     suffix="self_attn.q_proj",
                     target_module=ColCaiQuantLinear,
-                    kwargs={'split_num': 1},
+                    kwargs={"split_num": 1},
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.k_proj",
                     target_module=ColCaiQuantLinear,
-                    kwargs={'split_num': 1},
+                    kwargs={"split_num": 1},
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.v_proj",
                     target_module=ColCaiQuantLinear,
-                    kwargs={'split_num': 1},
+                    kwargs={"split_num": 1},
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.o_proj",
                     target_module=RowCaiQuantLinear,
-                    kwargs={'split_num': 1},
+                    kwargs={"split_num": 1},
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.gate_proj",
                     target_module=ColCaiQuantLinear,
-                    kwargs={'split_num': 1},
+                    kwargs={"split_num": 1},
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.up_proj",
                     target_module=ColCaiQuantLinear,
-                    kwargs={'split_num': 1},
+                    kwargs={"split_num": 1},
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.down_proj",
                     target_module=RowCaiQuantLinear,
-                    kwargs={'split_num': 1},
-                )
+                    kwargs={"split_num": 1},
+                ),
             ],
         )
...

@@ -117,3 +118,7 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
         )
         return policy
+
+    def postprocess(self):
+        init_to_get_rotary(self.model.model)
+        return self.model
colossalai/kernel/triton/__init__.py

@@ -3,6 +3,12 @@ try:
     HAS_TRITON = True
 except ImportError:
     HAS_TRITON = False
     print("Triton is not installed. Please install Triton to use Triton kernels.")

+# There may exist import error even if we have triton installed.
+if HAS_TRITON:
     from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
     from .copy_kv_cache_dest import copy_kv_cache_to_dest
     from .fused_layernorm import layer_norm
...

@@ -23,7 +29,3 @@ try:
         "token_attention_fwd",
         "gptq_fused_linear_triton",
     ]
-except ImportError:
-    HAS_TRITON = False
-    print("Triton is not installed. Please install Triton to use Triton kernels.")
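Callers can rely on the same module-level flag to degrade gracefully when Triton is unavailable; a minimal sketch (the fallback body is illustrative, not ColossalAI code):

# Minimal sketch: guard kernel usage on HAS_TRITON; the fallback is illustrative.
from colossalai.kernel.triton import HAS_TRITON

if HAS_TRITON:
    from colossalai.kernel.triton import copy_kv_cache_to_dest
else:
    def copy_kv_cache_to_dest(src, dest_index, dest_buffer):
        # naive PyTorch fallback: write the new K/V rows into the selected cache slots
        dest_buffer[dest_index] = src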
examples/inference/bench_llama.py

@@ -15,30 +15,6 @@ from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use
 os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"


-def init_to_get_rotary(self, base=10000):
-    self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
-    if not hasattr(self.config, "rope_scaling"):
-        rope_scaling_factor = 1.0
-    else:
-        rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
-    if hasattr(self.config, "max_sequence_length"):
-        max_seq_len = self.config.max_sequence_length
-    elif hasattr(self.config, "max_position_embeddings"):
-        max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
-    else:
-        max_seq_len = 2048 * rope_scaling_factor
-    base = float(base)
-    inv_freq = 1.0 / (base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_))
-    t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
-    freqs = torch.outer(t, inv_freq)
-    self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
-    self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
-    return
-
-
 def print_perf_stats(latency_set, config, bs, warmup=3):
     # trim warmup queries
     latency_set = list(latency_set)
...

@@ -66,7 +42,6 @@ def run_llama_test(args):
     tokenizer = LlamaTokenizer.from_pretrained(llama_model_path)
     tokenizer.pad_token_id = tokenizer.unk_token_id
     model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
-    init_to_get_rotary(model.model, base=10000)
     model = model.half()
     model_config = model.config
...
examples/inference/gptq_llama.py

 import argparse
 import logging
 import os
 import time

 import torch
-from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
-from auto_gptq.nn_modules.qlinear import GeneralQuantLinear
-from torch import distributed as dist
-from torch.profiler import ProfilerActivity, profile, record_function
-from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer, TextGenerationPipeline
+from auto_gptq import AutoGPTQForCausalLM
+from transformers import LlamaTokenizer

 import colossalai
 from colossalai.gptq import CaiQuantLinear
 from colossalai.gptq.gptq_tp import replace_autogptq_linear
 from colossalai.inference.tensor_parallel.engine import TPInferEngine
+from colossalai.inference.tensor_parallel.modeling._utils import init_to_get_rotary
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig
 from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

-os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
-
-
-def init_to_get_rotary(self, base=10000):
-    self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
-    if not hasattr(self.config, "rope_scaling"):
-        rope_scaling_factor = 1.0
-    else:
-        rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
-    if hasattr(self.config, "max_sequence_length"):
-        max_seq_len = self.config.max_sequence_length
-    elif hasattr(self.config, "max_position_embeddings"):
-        max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
-    else:
-        max_seq_len = 2048 * rope_scaling_factor
-    base = float(base)
-    inv_freq = 1.0 / (base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_))
-    t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
-    freqs = torch.outer(t, inv_freq)
-    self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
-    self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
-    return
-
+os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"


 def print_perf_stats(latency_set, config, bs, warmup=3):
...

@@ -74,23 +46,23 @@ def run_llama_test(args):
     tokenizer.pad_token_id = tokenizer.eos_token_id
     # load quantized model to the first GPU
     model = AutoGPTQForCausalLM.from_quantized(
         quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False
     )
     init_to_get_rotary(model.model.model, base=10000)
     model_config = model.config

     shard_config = ShardConfig(
         enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
     )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
     generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)

     input_tokens = {
-        "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'),
-        "attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda')
+        "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"),
+        "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"),
     }
     iters = 10
...

@@ -111,7 +83,7 @@ def run_llama_test(args):
 def check_llama(rank, world_size, port, args):
     disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     run_llama_test(args)
...

@@ -123,12 +95,12 @@ def test_llama(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('-p', '--path', type=str, help='Model path', required=True)
-    parser.add_argument('-q', '--quantized_path', type=str, help='Model path', required=True)
-    parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size')
-    parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size')
-    parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length')
-    parser.add_argument('--output_len', type=int, default=128, help='Maximum output length')
+    parser.add_argument("-p", "--path", type=str, help="Model path", required=True)
+    parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True)
+    parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
+    parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size")
+    parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length")
+    parser.add_argument("--output_len", type=int, default=128, help="Maximum output length")
     args = parser.parse_args()
...
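Putting the pieces of this example together, the quantized model is handed to the tensor-parallel engine and driven with greedy decoding; a condensed sketch follows (sizes come from the CLI flags above, and the generate() call is an assumption about TPInferEngine's interface):

# Condensed from the example above; generate() is an assumed interface.
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.shardformer import ShardConfig

shard_config = ShardConfig(
    enable_tensor_parallelism=args.tp_size > 1, inference_only=True, inference_gptq=True
)
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
outputs = infer_engine.generate(input_tokens, **generate_kwargs)  # assumed method name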
tests/test_infer/test_llama_infer.py

@@ -20,30 +20,6 @@ MAX_OUTPUT_LEN = 100
 CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")


-def init_to_get_rotary(self, base=10000):
-    self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
-    if not hasattr(self.config, "rope_scaling"):
-        rope_scaling_factor = 1.0
-    else:
-        rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
-    if hasattr(self.config, "max_sequence_length"):
-        max_seq_len = self.config.max_sequence_length
-    elif hasattr(self.config, "max_position_embeddings"):
-        max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
-    else:
-        max_seq_len = 2048 * rope_scaling_factor
-    base = float(base)
-    inv_freq = 1.0 / (base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_))
-    t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
-    freqs = torch.outer(t, inv_freq)
-    self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
-    self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
-    return
-
-
 @parameterize(
     "test_config",
     [
...

@@ -56,7 +32,6 @@ def run_llama_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_llama_for_casual_lm")
     for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
         orig_model = model_fn()
-        init_to_get_rotary(orig_model.model, base=10000)
         orig_model = orig_model.half()
         data = data_gen_fn()
...