Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
d69cd2eb
Unverified
Commit
d69cd2eb
authored
Jan 16, 2024
by
Frank Lee
Committed by
GitHub
Jan 16, 2024
Browse files
[workflow] fixed oom tests (#5275)
* [workflow] fixed oom tests * polish * polish * polish
parent
04244aaa
Changes
19
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
50 additions
and
582 deletions
+50
-582
.github/workflows/build_on_pr.yml
.github/workflows/build_on_pr.yml
+0
-2
tests/kit/model_zoo/registry.py
tests/kit/model_zoo/registry.py
+5
-2
tests/kit/model_zoo/transformers/gptj.py
tests/kit/model_zoo/transformers/gptj.py
+3
-0
tests/test_booster/test_plugin/test_gemini_plugin.py
tests/test_booster/test_plugin/test_gemini_plugin.py
+8
-1
tests/test_booster/test_plugin/test_low_level_zero_plugin.py
tests/test_booster/test_plugin/test_low_level_zero_plugin.py
+3
-2
tests/test_booster/test_plugin/test_torch_ddp_plugin.py
tests/test_booster/test_plugin/test_torch_ddp_plugin.py
+3
-2
tests/test_booster/test_plugin/test_torch_fsdp_plugin.py
tests/test_booster/test_plugin/test_torch_fsdp_plugin.py
+15
-3
tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
...heckpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
+5
-1
tests/test_infer_ops/triton/kernel_utils.py
tests/test_infer_ops/triton/kernel_utils.py
+0
-27
tests/test_infer_ops/triton/test_bloom_context_attention.py
tests/test_infer_ops/triton/test_bloom_context_attention.py
+0
-52
tests/test_infer_ops/triton/test_copy_kv_dest.py
tests/test_infer_ops/triton/test_copy_kv_dest.py
+0
-39
tests/test_infer_ops/triton/test_layernorm_triton.py
tests/test_infer_ops/triton/test_layernorm_triton.py
+0
-43
tests/test_infer_ops/triton/test_llama_act_combine.py
tests/test_infer_ops/triton/test_llama_act_combine.py
+0
-56
tests/test_infer_ops/triton/test_llama_context_attention.py
tests/test_infer_ops/triton/test_llama_context_attention.py
+0
-50
tests/test_infer_ops/triton/test_self_attention_nonfusion.py
tests/test_infer_ops/triton/test_self_attention_nonfusion.py
+0
-143
tests/test_infer_ops/triton/test_softmax.py
tests/test_infer_ops/triton/test_softmax.py
+0
-36
tests/test_infer_ops/triton/test_token_attn_fwd.py
tests/test_infer_ops/triton/test_token_attn_fwd.py
+0
-72
tests/test_infer_ops/triton/test_token_softmax.py
tests/test_infer_ops/triton/test_token_softmax.py
+0
-48
tests/test_lazy/test_models.py
tests/test_lazy/test_models.py
+8
-3
No files found.
.github/workflows/build_on_pr.yml
View file @
d69cd2eb
...
...
@@ -160,9 +160,7 @@ jobs:
--ignore tests/test_gptq \
--ignore tests/test_infer_ops \
--ignore tests/test_legacy \
--ignore tests/test_moe \
--ignore tests/test_smoothquant \
--ignore tests/test_checkpoint_io \
tests/
env
:
LD_LIBRARY_PATH
:
/github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
...
...
tests/kit/model_zoo/registry.py
View file @
d69cd2eb
...
...
@@ -61,7 +61,9 @@ class ModelZooRegistry(dict):
"""
self
[
name
]
=
(
model_fn
,
data_gen_fn
,
output_transform_fn
,
loss_fn
,
model_attribute
)
def
get_sub_registry
(
self
,
keyword
:
Union
[
str
,
List
[
str
]],
exclude
:
Union
[
str
,
List
[
str
]]
=
None
):
def
get_sub_registry
(
self
,
keyword
:
Union
[
str
,
List
[
str
]],
exclude
:
Union
[
str
,
List
[
str
]]
=
None
,
allow_empty
:
bool
=
False
):
"""
Get a sub registry with models that contain the keyword.
...
...
@@ -95,7 +97,8 @@ class ModelZooRegistry(dict):
if
not
should_exclude
:
new_dict
[
k
]
=
v
assert
len
(
new_dict
)
>
0
,
f
"No model found with keyword
{
keyword
}
"
if
not
allow_empty
:
assert
len
(
new_dict
)
>
0
,
f
"No model found with keyword
{
keyword
}
"
return
new_dict
...
...
tests/kit/model_zoo/transformers/gptj.py
View file @
d69cd2eb
...
...
@@ -63,6 +63,9 @@ config = transformers.GPTJConfig(
n_layer
=
2
,
n_head
=
4
,
vocab_size
=
50258
,
n_embd
=
256
,
hidden_size
=
256
,
n_positions
=
512
,
attn_pdrop
=
0
,
embd_pdrop
=
0
,
resid_pdrop
=
0
,
...
...
tests/test_booster/test_plugin/test_gemini_plugin.py
View file @
d69cd2eb
...
...
@@ -12,7 +12,13 @@ from colossalai.fx import is_compatible_with_meta
from
colossalai.lazy.lazy_init
import
LazyInitContext
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.tensor.colo_parameter
import
ColoParameter
from
colossalai.testing
import
clear_cache_before_run
,
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
colossalai.testing
import
(
clear_cache_before_run
,
parameterize
,
rerun_if_address_is_in_use
,
skip_if_not_enough_gpus
,
spawn
,
)
from
tests.kit.model_zoo
import
COMMON_MODELS
,
IS_FAST_TEST
,
model_zoo
...
...
@@ -172,6 +178,7 @@ def test_gemini_plugin(early_stop: bool = True):
@
pytest
.
mark
.
largedist
@
skip_if_not_enough_gpus
(
8
)
@
rerun_if_address_is_in_use
()
def
test_gemini_plugin_3d
(
early_stop
:
bool
=
True
):
spawn
(
run_dist
,
8
,
early_stop
=
early_stop
)
...
...
tests/test_booster/test_plugin/test_low_level_zero_plugin.py
View file @
d69cd2eb
...
...
@@ -10,8 +10,8 @@ from colossalai.booster import Booster
from
colossalai.booster.plugin
import
LowLevelZeroPlugin
# from colossalai.nn.optimizer import HybridAdam
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
tests.kit.model_zoo
import
model_zoo
,
IS_FAST_TEST
,
COMMON_MODELS
from
colossalai.testing
import
clear_cache_before_run
,
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
tests.kit.model_zoo
import
COMMON_MODELS
,
IS_FAST_TEST
,
model_zoo
# These models are not compatible with AMP
_AMP_ERR_MODELS
=
[
"timm_convit"
,
"deepfm_interactionarch"
]
...
...
@@ -21,6 +21,7 @@ _LOW_LEVEL_ZERO_ERR_MODELS = ["dlrm_interactionarch"]
_STUCK_MODELS
=
[
"transformers_albert_for_multiple_choice"
]
@
clear_cache_before_run
()
def
run_fn
(
stage
,
model_fn
,
data_gen_fn
,
output_transform_fn
)
->
Optional
[
str
]:
device
=
device_utils
.
get_current_device
()
try
:
...
...
tests/test_booster/test_plugin/test_torch_ddp_plugin.py
View file @
d69cd2eb
...
...
@@ -10,10 +10,11 @@ import colossalai
from
colossalai.booster
import
Booster
from
colossalai.booster.plugin
import
TorchDDPPlugin
from
colossalai.interface
import
OptimizerWrapper
from
colossalai.testing
import
rerun_if_address_is_in_use
,
spawn
from
tests.kit.model_zoo
import
model_zoo
,
IS_FAST_TEST
,
COMMON_MODELS
from
colossalai.testing
import
clear_cache_before_run
,
rerun_if_address_is_in_use
,
spawn
from
tests.kit.model_zoo
import
COMMON_MODELS
,
IS_FAST_TEST
,
model_zoo
@
clear_cache_before_run
()
def
run_fn
(
model_fn
,
data_gen_fn
,
output_transform_fn
):
plugin
=
TorchDDPPlugin
()
booster
=
Booster
(
plugin
=
plugin
)
...
...
tests/test_booster/test_plugin/test_torch_fsdp_plugin.py
View file @
d69cd2eb
...
...
@@ -11,11 +11,12 @@ if version.parse(torch.__version__) >= version.parse("1.12.0"):
from
colossalai.booster.plugin
import
TorchFSDPPlugin
from
colossalai.interface
import
OptimizerWrapper
from
colossalai.testing
import
rerun_if_address_is_in_use
,
spawn
from
tests.kit.model_zoo
import
model_zoo
,
IS_FAST_TEST
,
COMMON_MODELS
from
colossalai.testing
import
clear_cache_before_run
,
rerun_if_address_is_in_use
,
spawn
from
tests.kit.model_zoo
import
COMMON_MODELS
,
IS_FAST_TEST
,
model_zoo
# test basic fsdp function
@
clear_cache_before_run
()
def
run_fn
(
model_fn
,
data_gen_fn
,
output_transform_fn
):
plugin
=
TorchFSDPPlugin
()
booster
=
Booster
(
plugin
=
plugin
)
...
...
@@ -40,12 +41,18 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):
optimizer
.
clip_grad_by_norm
(
1.0
)
optimizer
.
step
()
del
model
del
optimizer
del
criterion
del
booster
del
plugin
def
check_torch_fsdp_plugin
():
if
IS_FAST_TEST
:
registry
=
model_zoo
.
get_sub_registry
(
COMMON_MODELS
)
else
:
registry
=
model_zoo
registry
=
model_zoo
.
get_sub_registry
(
"transformers_gptj"
)
for
name
,
(
model_fn
,
data_gen_fn
,
output_transform_fn
,
_
,
_
)
in
registry
.
items
():
if
any
(
...
...
@@ -59,6 +66,7 @@ def check_torch_fsdp_plugin():
]
):
continue
print
(
name
)
run_fn
(
model_fn
,
data_gen_fn
,
output_transform_fn
)
torch
.
cuda
.
empty_cache
()
...
...
@@ -73,3 +81,7 @@ def run_dist(rank, world_size, port):
@
rerun_if_address_is_in_use
()
def
test_torch_fsdp_plugin
():
spawn
(
run_dist
,
2
)
if
__name__
==
"__main__"
:
test_torch_fsdp_plugin
()
tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
View file @
d69cd2eb
...
...
@@ -38,11 +38,11 @@ else:
]
@
clear_cache_before_run
()
@
parameterize
(
"shard"
,
[
True
,
False
])
@
parameterize
(
"model_name"
,
[
"transformers_llama_for_casual_lm"
])
@
parameterize
(
"size_per_shard"
,
[
32
])
@
parameterize
(
"test_config"
,
TEST_CONFIGS
)
@
clear_cache_before_run
()
def
exam_state_dict
(
shard
:
bool
,
model_name
:
str
,
size_per_shard
:
int
,
test_config
:
dict
):
(
model_fn
,
data_gen_fn
,
output_transform_fn
,
loss_fn
,
_
)
=
next
(
iter
(
model_zoo
.
get_sub_registry
(
model_name
).
values
())
...
...
@@ -145,3 +145,7 @@ def run_dist(rank, world_size, port):
@
rerun_if_address_is_in_use
()
def
test_hybrid_ckpIO
(
world_size
):
spawn
(
run_dist
,
world_size
)
if
__name__
==
"__main__"
:
test_hybrid_ckpIO
(
4
)
tests/test_infer_ops/triton/kernel_utils.py
deleted
100644 → 0
View file @
04244aaa
import
math
import
torch
from
torch.nn
import
functional
as
F
def
torch_context_attention
(
xq
,
xk
,
xv
,
bs
,
seqlen
,
num_head
,
head_dim
):
"""
adepted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253
"""
xq
=
xq
.
view
(
bs
,
seqlen
,
num_head
,
head_dim
)
xk
=
xk
.
view
(
bs
,
seqlen
,
num_head
,
head_dim
)
xv
=
xv
.
view
(
bs
,
seqlen
,
num_head
,
head_dim
)
mask
=
torch
.
tril
(
torch
.
ones
(
seqlen
,
seqlen
),
diagonal
=
0
).
unsqueeze
(
0
).
unsqueeze
(
0
).
cuda
()
mask
[
mask
==
0.0
]
=
-
100000000.0
mask
=
mask
.
repeat
(
bs
,
num_head
,
1
,
1
)
keys
=
xk
values
=
xv
xq
=
xq
.
transpose
(
1
,
2
)
keys
=
keys
.
transpose
(
1
,
2
)
values
=
values
.
transpose
(
1
,
2
)
sm_scale
=
1
/
math
.
sqrt
(
head_dim
)
scores
=
torch
.
matmul
(
xq
,
keys
.
transpose
(
2
,
3
))
*
sm_scale
scores
=
F
.
softmax
(
scores
.
float
()
+
mask
,
dim
=-
1
).
to
(
dtype
=
torch
.
float16
)
output
=
torch
.
matmul
(
scores
,
values
).
transpose
(
1
,
2
).
contiguous
().
reshape
(
-
1
,
num_head
,
head_dim
)
return
output
tests/test_infer_ops/triton/test_bloom_context_attention.py
deleted
100644 → 0
View file @
04244aaa
import
pytest
import
torch
from
packaging
import
version
try
:
pass
from
colossalai.kernel.triton
import
bloom_context_attn_fwd
from
tests.test_infer_ops.triton.kernel_utils
import
torch_context_attention
HAS_TRITON
=
True
except
ImportError
:
HAS_TRITON
=
False
print
(
"please install triton from https://github.com/openai/triton"
)
TRITON_CUDA_SUPPORT
=
version
.
parse
(
torch
.
version
.
cuda
)
>
version
.
parse
(
"11.4"
)
@
pytest
.
mark
.
skipif
(
not
TRITON_CUDA_SUPPORT
or
not
HAS_TRITON
,
reason
=
"triton requires cuda version to be higher than 11.4"
)
def
test_bloom_context_attention
():
bs
=
4
head_num
=
8
seq_len
=
1024
head_dim
=
64
query
=
torch
.
randn
((
bs
*
seq_len
,
head_num
,
head_dim
),
dtype
=
torch
.
float16
,
device
=
"cuda"
)
k
=
torch
.
randn
((
bs
*
seq_len
,
head_num
,
head_dim
),
dtype
=
torch
.
float16
,
device
=
"cuda"
)
v
=
torch
.
randn
((
bs
*
seq_len
,
head_num
,
head_dim
),
dtype
=
torch
.
float16
,
device
=
"cuda"
)
max_input_len
=
seq_len
b_start
=
torch
.
zeros
((
bs
,),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
b_len
=
torch
.
zeros
((
bs
,),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
for
i
in
range
(
bs
):
b_start
[
i
]
=
i
*
seq_len
b_len
[
i
]
=
seq_len
o
=
torch
.
randn
((
bs
*
seq_len
,
head_num
,
head_dim
),
dtype
=
torch
.
float16
,
device
=
"cuda"
)
alibi
=
torch
.
zeros
((
head_num
,),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
bloom_context_attn_fwd
(
query
.
clone
(),
k
.
clone
(),
v
.
clone
(),
o
,
b_start
,
b_len
,
max_input_len
,
alibi
)
torch_out
=
torch_context_attention
(
query
.
clone
(),
k
.
clone
(),
v
.
clone
(),
bs
,
seq_len
,
head_num
,
head_dim
)
assert
torch
.
allclose
(
torch_out
.
cpu
(),
o
.
cpu
(),
rtol
=
1e-3
,
atol
=
1e-2
),
"outputs from triton and torch are not matched"
if
__name__
==
"__main__"
:
test_bloom_context_attention
()
tests/test_infer_ops/triton/test_copy_kv_dest.py
deleted
100644 → 0
View file @
04244aaa
import
pytest
import
torch
from
packaging
import
version
try
:
pass
from
colossalai.kernel.triton.copy_kv_cache_dest
import
copy_kv_cache_to_dest
HAS_TRITON
=
True
except
ImportError
:
HAS_TRITON
=
False
print
(
"please install triton from https://github.com/openai/triton"
)
TRITON_CUDA_SUPPORT
=
version
.
parse
(
torch
.
version
.
cuda
)
>
version
.
parse
(
"11.4"
)
@
pytest
.
mark
.
skipif
(
not
TRITON_CUDA_SUPPORT
or
not
HAS_TRITON
,
reason
=
"triton requires cuda version to be higher than 11.4"
)
def
test_kv_cache_copy_op
():
B_NTX
=
32
*
2048
head_num
=
8
head_dim
=
64
cache
=
torch
.
randn
((
B_NTX
,
head_num
,
head_dim
),
device
=
"cuda"
,
dtype
=
torch
.
float16
)
dest_index
=
torch
.
arange
(
0
,
B_NTX
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
dest_data
=
torch
.
ones
((
B_NTX
,
head_num
,
head_dim
),
device
=
"cuda"
,
dtype
=
torch
.
float16
)
copy_kv_cache_to_dest
(
cache
,
dest_index
,
dest_data
)
assert
torch
.
allclose
(
cache
.
cpu
(),
dest_data
.
cpu
(),
rtol
=
1e-3
,
atol
=
1e-3
),
"copy_kv_cache_to_dest outputs from triton and torch are not matched"
if
__name__
==
"__main__"
:
test_kv_cache_copy_op
()
tests/test_infer_ops/triton/test_layernorm_triton.py
deleted
100644 → 0
View file @
04244aaa
import
pytest
import
torch
from
packaging
import
version
from
colossalai.kernel.triton
import
layer_norm
from
colossalai.testing.utils
import
parameterize
try
:
pass
HAS_TRITON
=
True
except
ImportError
:
HAS_TRITON
=
False
print
(
"please install triton from https://github.com/openai/triton"
)
TRITON_CUDA_SUPPORT
=
version
.
parse
(
torch
.
version
.
cuda
)
>
version
.
parse
(
"11.4"
)
@
pytest
.
mark
.
skipif
(
not
TRITON_CUDA_SUPPORT
or
not
HAS_TRITON
,
reason
=
"triton requires cuda version to be higher than 11.4"
)
@
parameterize
(
"M"
,
[
2
,
4
,
8
,
16
])
@
parameterize
(
"N"
,
[
64
,
128
])
def
test_layer_norm
(
M
,
N
):
dtype
=
torch
.
float16
eps
=
1e-5
x_shape
=
(
M
,
N
)
w_shape
=
(
x_shape
[
-
1
],)
weight
=
torch
.
rand
(
w_shape
,
dtype
=
dtype
,
device
=
"cuda"
)
bias
=
torch
.
rand
(
w_shape
,
dtype
=
dtype
,
device
=
"cuda"
)
x
=
-
2.3
+
0.5
*
torch
.
randn
(
x_shape
,
dtype
=
dtype
,
device
=
"cuda"
)
y_triton
=
layer_norm
(
x
,
weight
,
bias
,
eps
)
y_torch
=
torch
.
nn
.
functional
.
layer_norm
(
x
,
w_shape
,
weight
,
bias
,
eps
).
to
(
dtype
)
assert
y_triton
.
shape
==
y_torch
.
shape
assert
y_triton
.
dtype
==
y_torch
.
dtype
print
(
"max delta: "
,
torch
.
max
(
torch
.
abs
(
y_triton
-
y_torch
)))
assert
torch
.
allclose
(
y_triton
,
y_torch
,
atol
=
1e-2
,
rtol
=
0
)
if
__name__
==
"__main__"
:
test_layer_norm
()
tests/test_infer_ops/triton/test_llama_act_combine.py
deleted
100644 → 0
View file @
04244aaa
import
pytest
import
torch
from
packaging
import
version
from
torch
import
nn
from
colossalai.kernel.triton.llama_act_combine_kernel
import
LlamaActCombine
try
:
import
triton
HAS_TRITON
=
True
except
ImportError
:
HAS_TRITON
=
False
print
(
"please install triton from https://github.com/openai/triton"
)
TRITON_CUDA_SUPPORT
=
version
.
parse
(
torch
.
version
.
cuda
)
>
version
.
parse
(
'11.4'
)
BATCH_SIZE
=
4
SEQ_LEN
=
16
HIDDEN_SIZE
=
32
def
SwiGLU
(
x
):
"""Gated linear unit activation function.
Args:
x : input array
axis: the axis along which the split should be computed (default: -1)
"""
size
=
x
.
shape
[
-
1
]
assert
size
%
2
==
0
,
"axis size must be divisible by 2"
x1
,
x2
=
torch
.
split
(
x
,
size
//
2
,
-
1
)
return
x1
*
(
x2
*
torch
.
sigmoid
(
x2
.
to
(
torch
.
float32
)).
to
(
x
.
dtype
))
@
pytest
.
mark
.
skipif
(
not
(
HAS_TRITON
and
TRITON_CUDA_SUPPORT
),
reason
=
"requires triton"
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
])
def
test_llama_act_combine
(
dtype
:
str
):
x_gate
=
torch
.
randn
(
BATCH_SIZE
,
SEQ_LEN
,
HIDDEN_SIZE
*
2
,
dtype
=
dtype
).
cuda
()
x_gate_torch
=
nn
.
Parameter
(
x_gate
.
detach
().
clone
())
x_gate_kernel
=
nn
.
Parameter
(
x_gate
.
detach
().
clone
())
x_up
=
torch
.
randn
(
BATCH_SIZE
,
SEQ_LEN
,
HIDDEN_SIZE
,
dtype
=
dtype
).
cuda
()
x_up_torch
=
nn
.
Parameter
(
x_up
.
detach
().
clone
())
x_up_kernel
=
nn
.
Parameter
(
x_up
.
detach
().
clone
())
torch_out
=
SwiGLU
(
x_gate_torch
)
*
x_up_torch
kernel_out
=
LlamaActCombine
.
apply
(
x_gate_kernel
,
x_up_kernel
)
atol
=
1e-5
if
dtype
==
torch
.
float32
else
5e-2
assert
torch
.
allclose
(
torch_out
,
kernel_out
,
atol
=
atol
)
torch_out
.
mean
().
backward
()
kernel_out
.
mean
().
backward
()
assert
all
(
grad
is
not
None
for
grad
in
[
x_gate_torch
.
grad
,
x_up_torch
.
grad
,
x_gate_kernel
.
grad
,
x_up_kernel
.
grad
])
assert
torch
.
allclose
(
x_gate_torch
.
grad
,
x_gate_kernel
.
grad
,
atol
=
atol
)
assert
torch
.
allclose
(
x_up_torch
.
grad
,
x_up_kernel
.
grad
,
atol
=
atol
)
if
__name__
==
'__main__'
:
test_llama_act_combine
(
torch
.
float16
)
tests/test_infer_ops/triton/test_llama_context_attention.py
deleted
100644 → 0
View file @
04244aaa
import
pytest
import
torch
from
packaging
import
version
try
:
pass
from
colossalai.kernel.triton
import
llama_context_attn_fwd
from
tests.test_infer_ops.triton.kernel_utils
import
torch_context_attention
HAS_TRITON
=
True
except
ImportError
:
HAS_TRITON
=
False
print
(
"please install triton from https://github.com/openai/triton"
)
TRITON_CUDA_SUPPORT
=
version
.
parse
(
torch
.
version
.
cuda
)
>
version
.
parse
(
"11.4"
)
@
pytest
.
mark
.
skipif
(
not
TRITON_CUDA_SUPPORT
or
not
HAS_TRITON
,
reason
=
"triton requires cuda version to be higher than 11.4"
)
def
test_llama_context_attention
():
bs
=
4
head_num
=
8
seq_len
=
1024
head_dim
=
64
query
=
torch
.
randn
((
bs
*
seq_len
,
head_num
,
head_dim
),
dtype
=
torch
.
float16
,
device
=
"cuda"
)
k
=
torch
.
randn
((
bs
*
seq_len
,
head_num
,
head_dim
),
dtype
=
torch
.
float16
,
device
=
"cuda"
)
v
=
torch
.
randn
((
bs
*
seq_len
,
head_num
,
head_dim
),
dtype
=
torch
.
float16
,
device
=
"cuda"
)
max_input_len
=
seq_len
b_start
=
torch
.
zeros
((
bs
,),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
b_len
=
torch
.
zeros
((
bs
,),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
for
i
in
range
(
bs
):
b_start
[
i
]
=
i
*
seq_len
b_len
[
i
]
=
seq_len
o
=
torch
.
randn
((
bs
*
seq_len
,
head_num
,
head_dim
),
dtype
=
torch
.
float16
,
device
=
"cuda"
)
llama_context_attn_fwd
(
query
.
clone
(),
k
.
clone
(),
v
.
clone
(),
o
,
b_start
,
b_len
,
max_input_len
)
torch_out
=
torch_context_attention
(
query
.
clone
(),
k
.
clone
(),
v
.
clone
(),
bs
,
seq_len
,
head_num
,
head_dim
)
assert
torch
.
allclose
(
torch_out
.
cpu
(),
o
.
cpu
(),
rtol
=
1e-3
,
atol
=
1e-3
),
"outputs from triton and torch are not matched"
if
__name__
==
"__main__"
:
test_llama_context_attention
()
tests/test_infer_ops/triton/test_self_attention_nonfusion.py
deleted
100644 → 0
View file @
04244aaa
import
pytest
import
torch
import
torch.nn.functional
as
F
from
packaging
import
version
try
:
import
triton
from
colossalai.kernel.triton.qkv_matmul_kernel
import
qkv_gemm_4d_kernel
from
colossalai.kernel.triton.self_attention_nofusion
import
self_attention_compute_using_triton
HAS_TRITON
=
True
except
ImportError
:
HAS_TRITON
=
False
print
(
"please install triton from https://github.com/openai/triton"
)
TRITON_CUDA_SUPPORT
=
version
.
parse
(
torch
.
version
.
cuda
)
>
version
.
parse
(
"11.4"
)
@
pytest
.
mark
.
skipif
(
not
TRITON_CUDA_SUPPORT
or
not
HAS_TRITON
,
reason
=
"triton requires cuda version to be higher than 11.4"
)
def
test_qkv_matmul
():
qkv
=
torch
.
randn
((
4
,
24
,
64
*
3
),
device
=
"cuda"
,
dtype
=
torch
.
float16
)
scale
=
1.2
head_size
=
32
batches
=
qkv
.
shape
[
0
]
d_model
=
qkv
.
shape
[
-
1
]
//
3
num_of_heads
=
d_model
//
head_size
q
=
qkv
[:,
:,
:
d_model
]
k
=
qkv
[:,
:,
d_model
:
d_model
*
2
]
q
=
q
.
view
(
batches
,
-
1
,
num_of_heads
,
head_size
)
k
=
k
.
view
(
batches
,
-
1
,
num_of_heads
,
head_size
)
q_copy
=
q
.
clone
()
k_copy
=
k
.
clone
()
q
=
torch
.
transpose
(
q
,
1
,
2
).
contiguous
()
k
=
torch
.
transpose
(
k
,
1
,
2
).
contiguous
()
k
=
torch
.
transpose
(
k
,
2
,
3
).
contiguous
()
torch_ouput
=
torch
.
einsum
(
"bnij,bnjk->bnik"
,
q
,
k
)
torch_ouput
*=
1.2
q
,
k
=
q_copy
,
k_copy
batches
,
M
,
H
,
K
=
q
.
shape
N
=
k
.
shape
[
1
]
score_output
=
torch
.
empty
((
batches
,
H
,
M
,
N
),
device
=
q
.
device
,
dtype
=
q
.
dtype
)
grid
=
lambda
meta
:
(
batches
,
H
,
triton
.
cdiv
(
M
,
meta
[
"BLOCK_SIZE_M"
])
*
triton
.
cdiv
(
N
,
meta
[
"BLOCK_SIZE_N"
]),
)
K
=
q
.
shape
[
3
]
qkv_gemm_4d_kernel
[
grid
](
q
,
k
,
score_output
,
M
,
N
,
K
,
q
.
stride
(
0
),
q
.
stride
(
2
),
q
.
stride
(
1
),
q
.
stride
(
3
),
k
.
stride
(
0
),
k
.
stride
(
2
),
k
.
stride
(
3
),
k
.
stride
(
1
),
score_output
.
stride
(
0
),
score_output
.
stride
(
1
),
score_output
.
stride
(
2
),
score_output
.
stride
(
3
),
scale
=
scale
,
# currently manually setting, later on we can use auto-tune config to match best setting
BLOCK_SIZE_M
=
64
,
BLOCK_SIZE_N
=
32
,
BLOCK_SIZE_K
=
32
,
GROUP_SIZE_M
=
8
,
)
check
=
torch
.
allclose
(
torch_ouput
.
cpu
(),
score_output
.
cpu
(),
rtol
=
1e-3
,
atol
=
1e-5
)
assert
check
is
True
,
"the outputs of triton and torch are not matched"
def
self_attention_compute_using_torch
(
qkv
,
input_mask
,
scale
,
head_size
):
batches
=
qkv
.
shape
[
0
]
d_model
=
qkv
.
shape
[
-
1
]
//
3
num_of_heads
=
d_model
//
head_size
q
=
qkv
[:,
:,
:
d_model
]
k
=
qkv
[:,
:,
d_model
:
d_model
*
2
]
v
=
qkv
[:,
:,
d_model
*
2
:]
q
=
q
.
view
(
batches
,
-
1
,
num_of_heads
,
head_size
)
k
=
k
.
view
(
batches
,
-
1
,
num_of_heads
,
head_size
)
v
=
v
.
view
(
batches
,
-
1
,
num_of_heads
,
head_size
)
q
=
torch
.
transpose
(
q
,
1
,
2
).
contiguous
()
k
=
torch
.
transpose
(
k
,
1
,
2
).
contiguous
()
v
=
torch
.
transpose
(
v
,
1
,
2
).
contiguous
()
k
=
torch
.
transpose
(
k
,
-
1
,
-
2
).
contiguous
()
score_output
=
torch
.
einsum
(
"bnij,bnjk->bnik"
,
q
,
k
)
score_output
*=
scale
softmax_output
=
F
.
softmax
(
score_output
,
dim
=-
1
)
res
=
torch
.
einsum
(
"bnij,bnjk->bnik"
,
softmax_output
,
v
)
res
=
torch
.
transpose
(
res
,
1
,
2
)
res
=
res
.
contiguous
()
return
res
.
view
(
batches
,
-
1
,
d_model
),
score_output
,
softmax_output
@
pytest
.
mark
.
skipif
(
not
TRITON_CUDA_SUPPORT
or
not
HAS_TRITON
,
reason
=
"triton requires cuda version to be higher than 11.4"
)
def
test_self_atttention_test
():
qkv
=
torch
.
randn
((
4
,
24
,
64
*
3
),
device
=
"cuda"
,
dtype
=
torch
.
float16
)
data_output_torch
,
score_output_torch
,
softmax_output_torch
=
self_attention_compute_using_torch
(
qkv
.
clone
(),
input_mask
=
None
,
scale
=
1.2
,
head_size
=
32
)
data_output_triton
=
self_attention_compute_using_triton
(
qkv
.
clone
(),
alibi
=
None
,
head_size
=
32
,
scale
=
1.2
,
input_mask
=
None
,
layer_past
=
None
,
use_flash
=
False
,
triangular
=
True
,
)
check
=
torch
.
allclose
(
data_output_triton
.
cpu
(),
data_output_torch
.
cpu
(),
rtol
=
1e-4
,
atol
=
1e-2
)
assert
check
is
True
,
"the triton output is not matched with torch output"
if
__name__
==
"__main__"
:
test_qkv_matmul
()
test_self_atttention_test
()
tests/test_infer_ops/triton/test_softmax.py
deleted
100644 → 0
View file @
04244aaa
import
pytest
import
torch
from
packaging
import
version
from
torch
import
nn
try
:
from
colossalai.kernel.triton.softmax
import
softmax
HAS_TRITON
=
True
except
ImportError
:
HAS_TRITON
=
False
print
(
"please install triton from https://github.com/openai/triton"
)
TRITON_CUDA_SUPPORT
=
version
.
parse
(
torch
.
version
.
cuda
)
>
version
.
parse
(
"11.4"
)
@
pytest
.
mark
.
skipif
(
not
TRITON_CUDA_SUPPORT
or
not
HAS_TRITON
,
reason
=
"triton requires cuda version to be higher than 11.4"
)
def
test_softmax_op
():
data_samples
=
[
torch
.
randn
((
3
,
4
,
5
,
32
),
device
=
"cuda"
,
dtype
=
torch
.
float32
),
torch
.
randn
((
320
,
320
,
78
),
device
=
"cuda"
,
dtype
=
torch
.
float32
),
torch
.
randn
((
2345
,
4
,
5
,
64
),
device
=
"cuda"
,
dtype
=
torch
.
float16
),
]
for
data
in
data_samples
:
module
=
nn
.
Softmax
(
dim
=-
1
)
data_torch_out
=
module
(
data
)
data_triton_out
=
softmax
(
data
)
check
=
torch
.
allclose
(
data_torch_out
.
cpu
(),
data_triton_out
.
cpu
(),
rtol
=
1e-3
,
atol
=
1e-3
)
assert
check
is
True
,
"softmax outputs from triton and torch are not matched"
if
__name__
==
"__main__"
:
test_softmax_op
()
tests/test_infer_ops/triton/test_token_attn_fwd.py
deleted
100644 → 0
View file @
04244aaa
import
pytest
import
torch
from
packaging
import
version
try
:
from
colossalai.kernel.triton.token_attention_kernel
import
token_attention_fwd
HAS_TRITON
=
True
except
ImportError
:
HAS_TRITON
=
False
print
(
"please install triton from https://github.com/openai/triton"
)
import
importlib.util
HAS_LIGHTLLM_KERNEL
=
True
if
importlib
.
util
.
find_spec
(
"lightllm"
)
is
None
:
HAS_LIGHTLLM_KERNEL
=
False
TRITON_CUDA_SUPPORT
=
version
.
parse
(
torch
.
version
.
cuda
)
>=
version
.
parse
(
"11.6"
)
def
torch_att
(
xq
,
xk
,
xv
,
bs
,
seqlen
,
num_head
,
head_dim
):
xq
=
xq
.
view
(
bs
,
1
,
num_head
,
head_dim
)
xk
=
xk
.
view
(
bs
,
seqlen
,
num_head
,
head_dim
)
xv
=
xv
.
view
(
bs
,
seqlen
,
num_head
,
head_dim
)
logics
=
torch
.
sum
(
xq
*
xk
,
dim
=
3
,
keepdim
=
False
)
*
1
/
(
head_dim
**
0.5
)
prob
=
torch
.
softmax
(
logics
,
dim
=
1
)
prob
=
prob
.
view
(
bs
,
seqlen
,
num_head
,
1
)
return
torch
.
sum
(
prob
*
xv
,
dim
=
1
,
keepdim
=
False
)
@
pytest
.
mark
.
skipif
(
not
TRITON_CUDA_SUPPORT
or
not
HAS_TRITON
or
not
HAS_LIGHTLLM_KERNEL
,
reason
=
"triton requires cuda version to be higher than 11.4 or not install lightllm"
,
)
def
test
():
Z
,
head_num
,
seq_len
,
head_dim
=
22
,
112
//
8
,
2048
,
128
dtype
=
torch
.
float16
q
=
torch
.
empty
((
Z
,
head_num
,
head_dim
),
dtype
=
dtype
,
device
=
"cuda"
).
normal_
(
mean
=
0.1
,
std
=
0.2
)
k
=
torch
.
empty
((
Z
*
seq_len
,
head_num
,
head_dim
),
dtype
=
dtype
,
device
=
"cuda"
).
normal_
(
mean
=
0.4
,
std
=
0.2
)
v
=
torch
.
empty
((
Z
*
seq_len
,
head_num
,
head_dim
),
dtype
=
dtype
,
device
=
"cuda"
).
normal_
(
mean
=
0.3
,
std
=
0.2
)
o
=
torch
.
empty
((
Z
,
head_num
,
head_dim
),
dtype
=
dtype
,
device
=
"cuda"
).
normal_
(
mean
=
0.3
,
std
=
0.2
)
alibi
=
torch
.
zeros
((
head_num
,),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
max_kv_cache_len
=
seq_len
kv_cache_start_loc
=
torch
.
zeros
((
Z
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
kv_cache_loc
=
torch
.
zeros
((
Z
,
seq_len
),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
kv_cache_seq_len
=
torch
.
ones
((
Z
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
kv_cache_seq_len
[:]
=
seq_len
kv_cache_start_loc
[
0
]
=
0
kv_cache_start_loc
[
1
]
=
seq_len
kv_cache_start_loc
[
2
]
=
2
*
seq_len
kv_cache_start_loc
[
3
]
=
3
*
seq_len
for
i
in
range
(
Z
):
kv_cache_loc
[
i
,
:]
=
torch
.
arange
(
i
*
seq_len
,
(
i
+
1
)
*
seq_len
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
token_attention_fwd
(
q
,
k
,
v
,
o
,
kv_cache_loc
,
kv_cache_start_loc
,
kv_cache_seq_len
,
max_kv_cache_len
,
alibi
=
alibi
)
torch_out
=
torch_att
(
q
,
k
,
v
,
Z
,
seq_len
,
head_num
,
head_dim
)
print
(
"max "
,
torch
.
max
(
torch
.
abs
(
torch_out
-
o
)))
print
(
"mean "
,
torch
.
mean
(
torch
.
abs
(
torch_out
-
o
)))
assert
torch
.
allclose
(
torch_out
,
o
,
atol
=
1e-2
,
rtol
=
0
)
if
__name__
==
"__main__"
:
test
()
tests/test_infer_ops/triton/test_token_softmax.py
deleted
100644 → 0
View file @
04244aaa
import
pytest
import
torch
from
packaging
import
version
try
:
pass
from
colossalai.kernel.triton.token_attention_kernel
import
token_attn_softmax_fwd
HAS_TRITON
=
True
except
ImportError
:
HAS_TRITON
=
False
print
(
"please install triton from https://github.com/openai/triton"
)
TRITON_CUDA_SUPPORT
=
version
.
parse
(
torch
.
version
.
cuda
)
>
version
.
parse
(
"11.4"
)
@
pytest
.
mark
.
skipif
(
not
TRITON_CUDA_SUPPORT
or
not
HAS_TRITON
,
reason
=
"triton requires cuda version to be higher than 11.4"
)
def
test_softmax
():
import
torch
batch_size
,
seq_len
,
head_num
,
head_dim
=
4
,
1025
,
12
,
128
dtype
=
torch
.
float16
Logics
=
torch
.
empty
((
head_num
,
batch_size
*
seq_len
),
dtype
=
dtype
,
device
=
"cuda"
).
normal_
(
mean
=
0.1
,
std
=
10
)
ProbOut
=
torch
.
empty
((
head_num
,
batch_size
*
seq_len
),
dtype
=
dtype
,
device
=
"cuda"
).
normal_
(
mean
=
0.4
,
std
=
0.2
)
kv_cache_start_loc
=
torch
.
zeros
((
batch_size
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
kv_cache_seq_len
=
torch
.
zeros
((
batch_size
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
for
i
in
range
(
batch_size
):
kv_cache_start_loc
[
i
]
=
i
*
seq_len
kv_cache_seq_len
[
i
]
=
seq_len
token_attn_softmax_fwd
(
Logics
,
kv_cache_start_loc
,
kv_cache_seq_len
,
ProbOut
,
seq_len
)
torch_out
=
Logics
.
reshape
(
head_num
*
batch_size
,
-
1
).
softmax
(
-
1
).
reshape
(
head_num
,
batch_size
*
seq_len
)
o
=
ProbOut
print
(
"max "
,
torch
.
max
(
torch
.
abs
(
torch_out
-
o
)))
print
(
"mean "
,
torch
.
mean
(
torch
.
abs
(
torch_out
-
o
)))
assert
torch
.
allclose
(
torch_out
,
o
,
atol
=
1e-2
,
rtol
=
0
)
if
__name__
==
"__main__"
:
test_softmax
()
tests/test_lazy/test_models.py
View file @
d69cd2eb
import
pytest
from
lazy_init_utils
import
SUPPORT_LAZY
,
check_lazy_init
from
tests.kit.model_zoo
import
model_zoo
,
IS_FAST_TEST
,
COMMON_MODELS
from
tests.kit.model_zoo
import
COMMON_MODELS
,
IS_FAST_TEST
,
model_zoo
@
pytest
.
mark
.
skipif
(
not
SUPPORT_LAZY
,
reason
=
"requires torch >= 1.12.0"
)
@
pytest
.
mark
.
parametrize
(
"subset"
,
[
COMMON_MODELS
]
if
IS_FAST_TEST
else
[
"torchvision"
,
"diffusers"
,
"timm"
,
"transformers"
,
"torchaudio"
,
"deepfm"
,
"dlrm"
])
@
pytest
.
mark
.
parametrize
(
"subset"
,
[
COMMON_MODELS
]
if
IS_FAST_TEST
else
[
"torchvision"
,
"diffusers"
,
"timm"
,
"transformers"
,
"torchaudio"
,
"deepfm"
,
"dlrm"
],
)
@
pytest
.
mark
.
parametrize
(
"default_device"
,
[
"cpu"
,
"cuda"
])
def
test_torchvision_models_lazy_init
(
subset
,
default_device
):
sub_model_zoo
=
model_zoo
.
get_sub_registry
(
subset
)
sub_model_zoo
=
model_zoo
.
get_sub_registry
(
subset
,
allow_empty
=
True
)
for
name
,
entry
in
sub_model_zoo
.
items
():
# TODO(ver217): lazy init does not support weight norm, skip these models
if
name
in
(
"torchaudio_wav2vec2_base"
,
"torchaudio_hubert_base"
)
or
name
.
startswith
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment