OpenDAS / ColossalAI · Commit f68eddfb

refactor kernel (#142)

Unverified commit f68eddfb, authored Jan 13, 2022 by ver217 and committed via GitHub on Jan 13, 2022.
Parent: 4a3d3446

Changes: 24 · Showing 4 changed files with 79 additions and 250 deletions (+79 −250).
colossalai/nn/optimizer/fused_sgd.py   +1 −2
csrc/compat.h                          +0 −10
csrc/type_shim.h                       +0 −202
setup.py                               +78 −36
colossalai/nn/optimizer/fused_sgd.py
@@ -90,8 +90,7 @@ class FusedSGD(Optimizer):
                                                     [0],
                                                     dtype=torch.int,
                                                     device=self.param_groups[0]["params"][0].device)
             self.multi_tensor_sgd = colossal_C.multi_tensor_sgd
         else:
-            raise RuntimeError(
-                'apex.optimizers.FusedSGD requires cuda extensions')
+            raise RuntimeError('FusedSGD requires cuda extensions')

     def __setstate__(self, state):
         super(FusedSGD, self).__setstate__(state)
csrc/compat.h (deleted, 100644 → 0)
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
\ No newline at end of file
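For context, compat.h only smooths over PyTorch C++ API renames; a minimal hedged sketch of how the DATA_PTR shim is typically used (the function below is illustrative and not part of this commit):

// Illustrative only, not part of this commit: with -DVERSION_GE_1_3 the macro
// expands to t.data_ptr<float>(), on older PyTorch builds to t.data<float>().
#include <ATen/ATen.h>
#include "compat.h"

float* float_data(at::Tensor t)
{
    return t.DATA_PTR<float>();
}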
csrc/type_shim.h (deleted, 100644 → 0)
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
#include <ATen/ATen.h>
#include "compat.h"
// Forward/backward compatiblity hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch (TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: \
{ \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
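The DISPATCH_* macros above switch on the runtime at::ScalarType and instantiate the macro body once per supported element type, exposing it as scalar_t_<LEVEL>. A minimal hedged usage sketch follows (the kernel, tensor, and launch shape are hypothetical, not taken from this commit):

// Illustrative only, not part of this commit.
template <typename scalar_t>
__global__ void scale_kernel(scalar_t* data, float factor, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        data[i] = static_cast<scalar_t>(static_cast<float>(data[i]) * factor);
}

void scale_inplace(at::Tensor t, float factor)
{
    int n = t.numel();
    // LEVEL 0 makes the dispatched element type available as scalar_t_0.
    DISPATCH_FLOAT_AND_HALF(t.scalar_type(), 0, "scale_inplace",
                            scale_kernel<scalar_t_0><<<(n + 255) / 256, 256>>>(
                                t.DATA_PTR<scalar_t_0>(), factor, n);)
}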
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(T *x,
                                                     T val,
                                                     int lanes = 1,
                                                     bool share_result = false)
// lanes is intended to be <= 32.
{
    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.

    if (blockSize >= 64)
    {
        x[tid] = val;
        __syncthreads();
    }

#pragma unroll
    for (int i = (blockSize >> 1); i >= 64; i >>= 1)
    {
        if (tid < i)
            x[tid] = x[tid] + x[tid + i];
        __syncthreads();
    }

    T final;

    if (tid < 32)
    {
        if (blockSize >= 64)
            final = x[tid] + x[tid + 32];
        else
            final = val;
        // __SYNCWARP();

#pragma unroll
        for (int i = 16; i >= lanes; i >>= 1)
            final = final + __shfl_down_sync(0xffffffff, final, i);
    }

    if (share_result)
    {
        if (tid < lanes)
            x[tid] = final; // EpilogueOp
        // Make sure the smem result is visible to all warps.
        __syncthreads();
    }

    return final;
}

template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op(T *x,
                                                            T val,
                                                            int lanes = 1,
                                                            bool share_result = false)
// lanes is intended to be <= 32.
{
    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.

    if (blockSize >= 64)
    {
        x[tid] = val;
        __syncthreads();
    }

#pragma unroll
    for (int i = (blockSize >> 1); i >= 64; i >>= 1)
    {
        if (tid < i)
            x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
        __syncthreads();
    }

    T final;

    if (tid < 32)
    {
        if (blockSize >= 64)
            final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
        else
            final = val;
        // __SYNCWARP();

#pragma unroll
        for (int i = 16; i >= lanes; i >>= 1)
            final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
    }

    if (share_result)
    {
        if (tid < lanes)
            x[tid] = final; // EpilogueOp
        // Make sure the smem result is visible to all warps.
        __syncthreads();
    }

    return final;
}
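Both helpers reduce one value per thread into the first `lanes` lanes of the block, using a caller-provided shared-memory buffer of blockSize elements plus warp shuffles. A minimal hedged sketch of a kernel using the sum variant (the kernel and its launch configuration are hypothetical, not from this commit):

// Illustrative only, not part of this commit. Launch with blockDim.x * blockDim.y
// a multiple of 32 and blockDim.x * blockDim.y * sizeof(float) dynamic shared memory.
__global__ void block_sum_kernel(const float* in, float* block_sums, int n)
{
    extern __shared__ float smem[]; // one float per thread, used as the scratch buffer `x`

    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    int gid = blockIdx.x * (blockDim.x * blockDim.y) + tid;
    float val = (gid < n) ? in[gid] : 0.f;

    // lanes = 1: the full block sum ends up in lane 0 of warp 0.
    float total = reduce_block_into_lanes(smem, val, 1, false);
    if (tid == 0)
        block_sums[blockIdx.x] = total;
}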
\ No newline at end of file
setup.py
@@ -11,8 +11,7 @@ this_dir = os.path.dirname(os.path.abspath(__file__))

 def get_cuda_bare_metal_version(cuda_dir):
-    raw_output = subprocess.check_output(
-        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
+    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
     release = output[release_idx].split(".")
@@ -23,8 +22,7 @@ def get_cuda_bare_metal_version(cuda_dir):

 def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(
-        cuda_dir)
+    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
     torch_binary_major = torch.version.cuda.split(".")[0]
     torch_binary_minor = torch.version.cuda.split(".")[1]
@@ -40,6 +38,13 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
             "You can try commenting out this check (at your own risk)."
         )


+def append_nvcc_threads(nvcc_extra_args):
+    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
+    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+        return nvcc_extra_args + ["--threads", "4"]
+    return nvcc_extra_args
+
+
 def fetch_requirements(path):
     with open(path, 'r') as fd:
         return [r.strip() for r in fd.readlines()]
@@ -67,8 +72,8 @@ print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))

 TORCH_MAJOR = int(torch.__version__.split('.')[0])
 TORCH_MINOR = int(torch.__version__.split('.')[1])

-if TORCH_MAJOR == 0 and TORCH_MINOR < 4:
-    raise RuntimeError("Colossal-AI requires Pytorch 0.4 or newer.\n" +
+if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 8):
+    raise RuntimeError("Colossal-AI requires Pytorch 1.8 or newer.\n" +
                        "The latest stable release can be obtained from https://pytorch.org/")

 cmdclass = {}
@@ -79,22 +84,9 @@ ext_modules = []
 # and
 # https://github.com/NVIDIA/apex/issues/456
 # https://github.com/pytorch/pytorch/commit/eb7b39e02f7d75c26d8a795ea8c7fd911334da7e#diff-4632522f237f1e4e728cb824300403ac

-version_ge_1_1 = []
-if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
-    version_ge_1_1 = ['-DVERSION_GE_1_1']
-version_ge_1_3 = []
-if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
-    version_ge_1_3 = ['-DVERSION_GE_1_3']
-version_ge_1_5 = []
-if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
-    version_ge_1_5 = ['-DVERSION_GE_1_5']
-version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
+version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']

 if "--cuda_ext" in sys.argv:
-    if TORCH_MAJOR == 0:
-        raise RuntimeError("--cuda_ext requires Pytorch 1.0 or later, "
-                           "found torch.__version__ = {}".format(torch.__version__))
     sys.argv.remove("--cuda_ext")

     if CUDA_HOME is None:
@@ -103,19 +95,66 @@ if "--cuda_ext" in sys.argv:
     else:
         check_cuda_torch_binary_vs_bare_metal(CUDA_HOME)

-    ext_modules.append(
-        CUDAExtension(name='colossal_C',
-                      sources=['csrc/colossal_C_frontend.cpp',
-                               'csrc/multi_tensor_sgd_kernel.cu',
-                               'csrc/multi_tensor_scale_kernel.cu',
-                               'csrc/multi_tensor_adam.cu',
-                               'csrc/multi_tensor_l2norm_kernel.cu',
-                               'csrc/multi_tensor_lamb.cu'],
-                      extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
-                                          'nvcc': ['-lineinfo',
-                                                   '-O3',
-                                                   # '--resource-usage',
-                                                   '--use_fast_math'] + version_dependent_macros}))
+    def cuda_ext_helper(name, sources, extra_cuda_flags):
+        return CUDAExtension(name=name,
+                             sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in sources],
+                             include_dirs=[os.path.join(this_dir, 'colossalai/kernel/cuda_native/csrc/kernels/include')],
+                             extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
+                                                 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + version_dependent_macros + extra_cuda_flags)})
+
+    ext_modules.append(cuda_ext_helper('colossal_C',
+                                       ['colossal_C_frontend.cpp',
+                                        'multi_tensor_sgd_kernel.cu',
+                                        'multi_tensor_scale_kernel.cu',
+                                        'multi_tensor_adam.cu',
+                                        'multi_tensor_l2norm_kernel.cu',
+                                        'multi_tensor_lamb.cu'],
+                                       ['-lineinfo']))
+
+    cc_flag = ['-gencode', 'arch=compute_70,code=sm_70']
+    _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
+    if int(bare_metal_major) >= 11:
+        cc_flag.append('-gencode')
+        cc_flag.append('arch=compute_80,code=sm_80')
+
+    extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
+                        '-U__CUDA_NO_HALF_CONVERSIONS__',
+                        '--expt-relaxed-constexpr',
+                        '--expt-extended-lambda']
+
+    ext_modules.append(cuda_ext_helper('colossal_scaled_upper_triang_masked_softmax',
+                                       ['scaled_upper_triang_masked_softmax.cpp',
+                                        'scaled_upper_triang_masked_softmax_cuda.cu'],
+                                       extra_cuda_flags + cc_flag))
+
+    ext_modules.append(cuda_ext_helper('colossal_scaled_masked_softmax',
+                                       ['scaled_masked_softmax.cpp',
+                                        'scaled_masked_softmax_cuda.cu'],
+                                       extra_cuda_flags + cc_flag))
+
+    extra_cuda_flags = ['-maxrregcount=50']
+
+    ext_modules.append(cuda_ext_helper('colossal_layer_norm_cuda',
+                                       ['layer_norm_cuda.cpp',
+                                        'layer_norm_cuda_kernel.cu'],
+                                       extra_cuda_flags + cc_flag))
+
+    extra_cuda_flags = ['-std=c++14',
+                        '-U__CUDA_NO_HALF_OPERATORS__',
+                        '-U__CUDA_NO_HALF_CONVERSIONS__',
+                        '-U__CUDA_NO_HALF2_OPERATORS__',
+                        '-DTHRUST_IGNORE_CUB_VERSION_CHECK']
+
+    ext_modules.append(cuda_ext_helper('colossal_multihead_attention',
+                                       ['multihead_attention_1d.cpp',
+                                        'kernels/cublas_wrappers.cu',
+                                        'kernels/transform_kernels.cu',
+                                        'kernels/dropout_kernels.cu',
+                                        'kernels/normalize_kernels.cu',
+                                        'kernels/softmax_kernels.cu',
+                                        'kernels/general_kernels.cu',
+                                        'kernels/cuda_util.cu'],
+                                       extra_cuda_flags + cc_flag))

 install_requires = fetch_requirements('requirements/requirements.txt')
@@ -123,14 +162,17 @@ install_requires = fetch_requirements('requirements/requirements.txt')

 setup(
     name='colossalai',
     version='0.0.1-beta',
-    packages=find_packages(exclude=('csrc',
+    packages=find_packages(exclude=('benchmark',
+                                    'docker',
                                     'tests',
                                     'docs',
+                                    'examples',
                                     'tests',
+                                    'scripts',
+                                    'requirements',
                                     '*.egg-info',)),
     description='An integrated large-scale model training system with efficient parallelization techniques',
     ext_modules=ext_modules,
     cmdclass={'build_ext': BuildExtension} if ext_modules else {},
+    package_data={'colossalai': ['kernel/cuda_native/csrc/*']},
     install_requires=install_requires,
 )
\ No newline at end of file