OpenDAS / ColossalAI / Commits / d1fcc0fa

Unverified commit d1fcc0fa, authored Oct 04, 2023 by Xu Kai, committed by GitHub on Oct 04, 2023.
[infer] fix test bug (#4838)

* fix test bug
* delete useless code
* fix typo
parent 013a4bed
Showing 6 changed files with 56 additions and 51 deletions (+56 -51):

  colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py    +1  -1
  examples/inference/bench_llama.py                                   +0  -1
  tests/test_infer/test_bloom_infer.py                                +16 -13
  tests/test_infer/test_chatglm2_infer.py                             +21 -20
  tests/test_infer/test_llama_infer.py                                +17 -13
  tests/test_infer_ops/triton/test_llama2_token_attn.py               +1  -3
colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py (view file @ d1fcc0fa)
...
...
@@ -873,7 +873,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         self.rotary_pos_emb = RotaryEmbedding(
             rotary_dim // 2,
-            original_impl=config.original_rope,
+            # original_impl=config.original_rope,  # config has no attribute original_rope
             device=device,
             dtype=config.torch_dtype,
         )
...
...
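The fix simply comments the argument out. When a config object may or may not define a field like `original_rope`, another option is `getattr` with a default; the snippet below is a hypothetical illustration of that pattern (using `SimpleNamespace` stand-ins for the config), not code from this commit.

    from types import SimpleNamespace

    # Stand-ins for two config variants: one defines original_rope, one does not.
    config_full = SimpleNamespace(original_rope=True, torch_dtype="float16")
    config_slim = SimpleNamespace(torch_dtype="float16")

    # Direct attribute access fails on the slim config ...
    try:
        _ = config_slim.original_rope
    except AttributeError as err:
        print("missing attribute:", err)

    # ... while getattr with a default works for both variants.
    for cfg in (config_full, config_slim):
        print(getattr(cfg, "original_rope", True))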
examples/inference/bench_llama.py (view file @ d1fcc0fa)
...
...
@@ -43,7 +43,6 @@ def run_llama_test(args):
     tokenizer.pad_token_id = tokenizer.unk_token_id
     model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
     model = model.half()
-    model_config = model.config
     shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
...
...
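A side note on the unchanged context line above: `ShardConfig` takes a plain boolean, so the `True if args.tp_size > 1 else False` ternary could be written as the bare comparison. The commit does not make this change; the snippet below only illustrates the equivalence with a hypothetical `Args` stand-in.

    class Args:
        tp_size = 4  # stand-in for the parsed CLI arguments in bench_llama.py

    args = Args()
    verbose_flag = True if args.tp_size > 1 else False
    terse_flag = args.tp_size > 1  # a comparison already evaluates to a bool
    assert verbose_flag == terse_flag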
tests/test_infer/test_bloom_infer.py (view file @ d1fcc0fa)
 import pytest
 import torch
 from packaging import version
+from transformers import BloomForCausalLM
+from transformers.models.bloom.configuration_bloom import BloomConfig
 
 import colossalai
 from colossalai.inference.tensor_parallel import TPInferEngine
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo import model_zoo
 
 TP_SIZE = 2
 MAX_BATCH_SIZE = 4
...
...
@@ -26,21 +27,23 @@ CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
     ],
 )
 def run(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_bloom_for_causal_lm")
-    for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
-        orig_model = model_fn()
-        orig_model = orig_model.half()
-        data = data_gen_fn()
+    bloom_config = BloomConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
+    model = BloomForCausalLM(bloom_config)
+    model = model.half()
 
-        shard_config = ShardConfig(enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True)
-        infer_engine = TPInferEngine(orig_model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    shard_config = ShardConfig(enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True)
+    infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
 
-        generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
-        outputs = infer_engine.generate(data, **generate_kwargs)
+    generate_kwargs = dict(do_sample=False)
+    input_tokens = {
+        "input_ids": torch.randint(1, 1000, (MAX_BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
+        "attention_mask": torch.ones((MAX_BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
+    }
+    outputs = infer_engine.generate(input_tokens, **generate_kwargs)
 
-        assert outputs is not None
+    assert outputs is not None
 
 
 def check_bloom(rank, world_size, port):
...
...
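The rewritten test builds a miniature model from an explicit BloomConfig and feeds it randomly generated token ids instead of going through the model zoo. The same shape of test can be exercised with plain transformers on CPU, without TPInferEngine; the sketch below is a standalone illustration under that assumption, not code from the commit (batch and length values are arbitrary).

    import torch
    from transformers import BloomConfig, BloomForCausalLM

    # Tiny config so the model is cheap to build; values mirror the test above.
    config = BloomConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
    model = BloomForCausalLM(config).eval()

    batch_size, input_len, output_len = 4, 16, 8
    input_tokens = {
        "input_ids": torch.randint(1, 1000, (batch_size, input_len)),
        "attention_mask": torch.ones((batch_size, input_len), dtype=torch.long),
    }

    with torch.no_grad():
        outputs = model.generate(**input_tokens, max_new_tokens=output_len, do_sample=False)

    assert outputs is not None
    print(outputs.shape)  # (batch_size, input_len + newly generated tokens)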
tests/test_infer/test_chatglm2_infer.py (view file @ d1fcc0fa)
...
...
@@ -2,17 +2,15 @@ import os
 import pytest
 import torch
-import torch.distributed as dist
 from packaging import version
 from transformers import AutoTokenizer
 
 import colossalai
 from colossalai.inference.tensor_parallel.engine import TPInferEngine
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig
 from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
 from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo.transformers.chatglm2 import infer_config
 
 os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
 
 TPSIZE = 1
...
...
@@ -31,28 +29,31 @@ CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
     ],
 )
 def run_chatglm2_test(test_config):
     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
     # pad_token_id = 0
-    model_fn = lambda: ChatGLMForConditionalGeneration(infer_config, empty_init=False)
-    orig_model = model_fn()
-    orig_model = orig_model.half()
     text = ["how is the weather today?"]
     input_ids = tokenizer.batch_encode_plus(text, return_tensors="pt", padding=True)
 
+    chatglm_config = ChatGLMConfig(
+        num_layers=2, vocab_size=1200, use_cache=True, multi_query_attention=True,
+        multi_query_group_num=2, num_attention_heads=8, hidden_size=1024,
+    )
+    model = ChatGLMForConditionalGeneration(chatglm_config)
+    model = model.half()
+
     shard_config = ShardConfig(enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True)
-    infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
     generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
     outputs = infer_engine.generate(input_ids, **generate_kwargs)
     assert outputs is not None
-    # print("outputs.shape: ", outputs[0].shape)
-    # print("outputs: ", outputs[0])
-    if not dist.is_initialized() or dist.get_rank() == 0:
-        for o in outputs:
-            output_text = tokenizer.decode(o)
-            print(output_text)
+
+    input_tokens = {
+        "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
+        "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
+    }
+    outputs = infer_engine.generate(input_tokens, **generate_kwargs)
+    assert outputs is not None
 
 
 def check_chatglm2(rank, world_size, port):
...
...
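The test keeps the tokenizer path: `batch_encode_plus` returns a dict-like BatchEncoding containing `input_ids` and `attention_mask`, which is what `TPInferEngine.generate` is handed above. A small standalone illustration follows; the `gpt2` checkpoint is just an assumption for the sketch, any public tokenizer works.

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 defines no pad token by default

    text = ["how is the weather today?", "hello"]
    batch = tokenizer.batch_encode_plus(text, return_tensors="pt", padding=True)

    print(list(batch.keys()))        # ['input_ids', 'attention_mask']
    print(batch["input_ids"].shape)  # (2, padded sequence length)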
tests/test_infer/test_llama_infer.py (view file @ d1fcc0fa)
...
...
@@ -3,13 +3,14 @@ import os
 import pytest
 import torch
 from packaging import version
+from transformers import LlamaForCausalLM
+from transformers.models.llama.configuration_llama import LlamaConfig
 
 import colossalai
 from colossalai.inference.tensor_parallel.engine import TPInferEngine
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo import model_zoo
 
 os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
 
 TPSIZE = 2
...
...
@@ -29,21 +30,24 @@ CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
     ],
 )
 def run_llama_test(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_llama_for_casual_lm")
-    for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
-        orig_model = model_fn()
-        orig_model = orig_model.half()
-        data = data_gen_fn()
+    llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
+    model = LlamaForCausalLM(llama_config)
+    model = model.half()
 
-        shard_config = ShardConfig(enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True)
-        infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    shard_config = ShardConfig(enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True)
+    infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    init_to_get_rotary(model.model, base=10000)
 
-        generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
-        outputs = infer_engine.generate(data, **generate_kwargs)
+    generate_kwargs = dict(do_sample=False)
+    input_tokens = {
+        "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
+        "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
+    }
+    outputs = infer_engine.generate(input_tokens, **generate_kwargs)
 
-        assert outputs is not None
+    assert outputs is not None
 
 
 def check_llama(rank, world_size, port):
...
...
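Each test file ends with a `check_*` entry point (here `check_llama(rank, world_size, port)`) that the imported `spawn` helper runs in multiple processes; that wiring sits in the collapsed part of the diff. The sketch below reconstructs the usual ColossalAI test wrapper from the imports shown above, so the exact calls and arguments (`colossalai.launch(...)`, `spawn(check_llama, TPSIZE)`) should be read as assumptions about the surrounding file, not as lines from this commit.

    # Assumed shape of the part of the file that the diff view collapses.
    def check_llama(rank, world_size, port):
        disable_existing_loggers()
        colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
        run_llama_test()


    @pytest.mark.skipif(not CUDA_SUPPORT, reason="requires CUDA newer than 11.5")
    @pytest.mark.dist
    @rerun_if_address_is_in_use()
    @clear_cache_before_run()
    def test_llama_infer():
        spawn(check_llama, TPSIZE)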
tests/test_infer_ops/triton/test_llama2_token_attn.py (view file @ d1fcc0fa)
...
...
@@ -38,9 +38,7 @@ def test():
     q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
     k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)
     v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
-    o = torch.empty_like()
-    # o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
+    o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda")
 
     max_kv_cache_len = seq_len
     kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda")
     kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda")
...
...
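The removed `torch.empty_like()` call could never have run: `empty_like` needs a reference tensor whose shape, dtype, and device it copies, which is why the fix switches to an explicit `torch.empty(...)`. A small CPU-only illustration (shapes are arbitrary for the sketch):

    import torch

    q = torch.empty((4, 8, 64), dtype=torch.float32).normal_(mean=0.1, std=0.2)

    # Calling empty_like with no argument raises immediately.
    try:
        o = torch.empty_like()
    except TypeError as err:
        print("empty_like() without arguments fails:", err)

    # Either of these yields an uninitialized buffer of the right shape:
    o1 = torch.empty_like(q)                           # copies shape/dtype/device from q
    o2 = torch.empty((4, 8, 64), dtype=torch.float32)  # spells the shape out, as in the fix
    assert o1.shape == o2.shape == q.shape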