OpenDAS / ColossalAI · Commits

Commit 58580b50 (unverified) · authored May 17, 2022 by ver217, committed by GitHub on May 17, 2022

Revert "[NFC] Hotfix/format (#984)" (#986)

This reverts commit 0772828f.

Parent: 0772828f

Changes: 35
Showing 20 changed files with 260 additions and 206 deletions (+260 −206)
colossalai/__init__.py  +0 −1
colossalai/builder/pipeline.py  +4 −3
colossalai/kernel/cuda_native/csrc/cpu_adam.cpp  +48 −31
colossalai/kernel/cuda_native/csrc/cpu_adam.h  +13 −19
colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu  +1 −1
colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu  +30 −16
colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu  +2 −2
colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h  +20 −21
colossalai/kernel/cuda_native/csrc/kernels/include/context.h  +2 −2
colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h  +3 −4
colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h  +3 −4
colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h  +3 −5
colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h  +4 −6
colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h  +3 −4
colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu  +1 −0
colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu  +6 −4
colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp  +12 −12
colossalai/kernel/cuda_native/csrc/moe_cuda.cpp  +16 −14
colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu  +52 −25
colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu  +37 −32
colossalai/__init__.py · View file @ 58580b50

@@ -2,4 +2,3 @@ from .initialize import (initialize, launch, launch_from_openmpi,
                          launch_from_slurm, launch_from_torch,
                          get_default_parser)
 __version__ = '0.0.1'
colossalai/builder/pipeline.py · View file @ 58580b50

@@ -251,9 +251,9 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo
    partitions = partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
    module_list = []
    for start, end in partitions[pipeline_rank]:
        module_list.append(
            nn.Sequential(*[nn.Identity() for _ in range(start)],
                          *layers[start:end],
                          *[nn.Identity() for _ in range(len(layers) - end)]))
    if verbose:
        logger = get_dist_logger()
        logger.info(f'Total {len(layers)} layers', ranks=[0])

@@ -264,3 +264,4 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo
            log_str += '\n'.join([str(layer) for layer in layers[start:end]]) + '\n'
        logger.info(log_str, ranks=[0])
    return nn.ModuleList(module_list) if len(module_list) > 1 else module_list[0]
\ No newline at end of file
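The hunks above only show the tail of build_pipeline_model. For orientation, the following is a minimal, self-contained sketch of the idea the diff exercises — pad each pipeline stage's slice of layers with nn.Identity modules so every stage keeps the same nominal layer count. It is not ColossalAI's actual implementation; uniform_partition below is a hypothetical stand-in for partition_uniform.

import torch.nn as nn

def uniform_partition(num_layers: int, num_stages: int):
    # Hypothetical stand-in for partition_uniform: split num_layers into
    # num_stages contiguous (start, end) ranges of near-equal size.
    base, rem = divmod(num_layers, num_stages)
    bounds, start = [], 0
    for stage in range(num_stages):
        end = start + base + (1 if stage < rem else 0)
        bounds.append((start, end))
        start = end
    return bounds

def stage_module(layers: nn.Sequential, start: int, end: int) -> nn.Sequential:
    # Mirrors the diff: keep this stage's layers, pad the rest with Identity.
    return nn.Sequential(*[nn.Identity() for _ in range(start)],
                         *list(layers)[start:end],
                         *[nn.Identity() for _ in range(len(layers) - end)])

layers = nn.Sequential(*[nn.Linear(8, 8) for _ in range(6)])
for start, end in uniform_partition(len(layers), num_stages=2):
    print(start, end, stage_module(layers, start, end))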
colossalai/kernel/cuda_native/csrc/cpu_adam.cpp · View file @ 58580b50

@@ -20,14 +20,12 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
*/

#include "cpu_adam.h"
#include <math.h>
#include <omp.h>
#include <string.h>
#include <torch/extension.h>
#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>

@@ -84,7 +82,8 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
    size_t offset = copy_size + t;
#pragma omp parallel for

@@ -146,7 +145,8 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
  if (_param_size > rounded_size) {
    for (size_t t = rounded_size; t < _param_size; t += TILE) {
      size_t copy_size = TILE;
      if ((t + TILE) > _param_size) copy_size = _param_size - t;
      size_t offset = copy_size + t;
#pragma omp parallel for

@@ -235,7 +235,8 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
(the same tiled copy loop over rounded_size as in Step_1)

@@ -320,6 +321,7 @@ int create_adam_optimizer(int optimizer_id, float alpha = 1e-3,
  s_optimizers[optimizer_id] = opt;

  if (should_log) {
    std::string avx_type = "";
#if defined(__AVX512__)
    avx_type = "AVX512";

@@ -384,7 +386,8 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
(again the same tiled copy loop over rounded_size, in Step_8)

@@ -460,29 +463,43 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
              grad_half_precision, loss_scale);
}

int adam_step(int optimizer_id, size_t step, float lr, float beta1, float beta2,
              float epsilon, float weight_decay, bool bias_correction,
              torch::Tensor &params, torch::Tensor &grads,
              torch::Tensor &exp_avg, torch::Tensor &exp_avg_sq,
              float loss_scale) {
  auto params_c = params.contiguous();
  auto grads_c = grads.contiguous();
  auto exp_avg_c = exp_avg.contiguous();
  auto exp_avg_sq_c = exp_avg_sq.contiguous();

  float *params_ptr = (float *)params_c.data_ptr();
  float *grads_ptr = (float *)grads_c.data_ptr();
  float *exp_avg_ptr = (float *)exp_avg_c.data_ptr();
  float *exp_avg_sq_ptr = (float *)exp_avg_sq_c.data_ptr();

  std::shared_ptr<Adam_Optimizer> opt =
      std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
  opt->IncrementStep(step, beta1, beta2);
  opt->update_state(lr, epsilon, weight_decay, bias_correction);
  opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
              params_c.numel(), (params.options().dtype() == at::kHalf),
              (grads.options().dtype() == at::kHalf), loss_scale);

  return 0;
}

int destroy_adam_optimizer(int optimizer_id) {
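The restored adam_step/Step_8 path above updates params, exp_avg and exp_avg_sq in fused SIMD tiles. For reference, here is a plain-PyTorch sketch of the standard Adam/AdamW update those buffers correspond to — the textbook formulation, not the fused kernel; adamw_mode mirrors the constructor flag declared in cpu_adam.h.

import torch

def reference_adam_step(param, grad, exp_avg, exp_avg_sq, step, lr,
                        beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.0,
                        bias_correction=True, adamw_mode=True):
    # Standard Adam/AdamW math; the fused Step_1/4/8 kernels process the same
    # quantities tile by tile with SIMD loads/stores.
    if adamw_mode and weight_decay != 0.0:
        param.mul_(1.0 - lr * weight_decay)          # decoupled weight decay
    elif weight_decay != 0.0:
        grad = grad.add(param, alpha=weight_decay)   # L2-style decay
    exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
    bc1 = 1.0 - beta1 ** step if bias_correction else 1.0
    bc2 = 1.0 - beta2 ** step if bias_correction else 1.0
    denom = (exp_avg_sq / bc2).sqrt().add_(eps)
    param.addcdiv_(exp_avg, denom, value=-lr / bc1)

p, g = torch.randn(10), torch.randn(10)
m, v = torch.zeros(10), torch.zeros(10)
reference_adam_step(p, g, m, v, step=1, lr=1e-3)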
colossalai/kernel/cuda_native/csrc/cpu_adam.h · View file @ 58580b50

@@ -48,10 +48,10 @@ SOFTWARE
#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
#define SIMD_LOAD_HALF(x) \
  _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
#define SIMD_STORE_HALF(x, d) \
  _mm256_store_ps( \
      x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))

#elif defined(__AVX256__) or defined(__AVX2__)

@@ -66,8 +66,8 @@ SOFTWARE
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
#define SIMD_STORE_HALF(x, d) \
  _mm_store_ps( \
      x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
#endif

@@ -83,25 +83,19 @@ union AVX_Data {
#endif

#define STEP(SPAN) \
  void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \
                   float *_exp_avg_sq, size_t _param_size, \
                   bool param_half_precision = false, \
                   bool grad_half_precision = false, float loss_scale = -1);

class Adam_Optimizer {
 public:
  Adam_Optimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
                 float eps = 1e-8, float weight_decay = 0,
                 bool adamw_mode = true)
      : _alpha(alpha),
        _betta1(betta1),
        _betta2(betta2),
        _eps(eps),
        _weight_decay(weight_decay),
        _betta1_t(1.0),
        _betta2_t(1.0),
        _step(0),
        _adamw_mode(adamw_mode) {}
  ~Adam_Optimizer() {}

@@ -141,7 +135,7 @@ class Adam_Optimizer {
    }
  }

 private:
  float _alpha;
  float _betta1;
  float _betta2;
colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu · View file @ 58580b50

@@ -16,7 +16,7 @@ __global__ void ls_cross_entropy_fw_kernel(
  const int left_idx = block_start + threadIdx.x;
  const int right_idx = (blockIdx.x + 1) * vocab_size;
  float max_input[1] = {REDUCE_FLOAT_INF_NEG};
  float sum_logits[2] = {0.f, 0.f};  // logit and logit exp
  int target_tid = targets[blockIdx.x];

  if (target_tid == padding_idx) {
colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu · View file @ 58580b50

#include <cooperative_groups.h>
#include <chrono>
#include <ctime>

#include "kernels.h"

namespace cg = cooperative_groups;

curandStatePhilox4_32_10_t *curandstate;

@@ -165,7 +165,8 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,
  const float scale = 1.f / (1.f - ratio);
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i * 4 >= total_count) return;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);

@@ -201,7 +202,8 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i * 8 >= total_count) return;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);

@@ -259,7 +261,8 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
  const float scale = 1.f / (1.f - ratio);
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i * 4 >= total_count) return;

  uint8_t m[4];

@@ -286,7 +289,8 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i * 8 >= total_count) return;

  float4 *out4 = reinterpret_cast<float4 *>(out);
  const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);

@@ -376,7 +380,8 @@ and @@ -419,7 +424,8 @@ __global__ void ls_dropout_res_bias_kernel(
(the same scale / thread-index / early-return / curand_init pattern as above,
 for the 4-element and 8-element paths)

@@ -559,9 +565,11 @@ __global__ void ls_dropout_bias_bwd_kernel(
  }
  __syncthreads();

  for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i);

  if (y == 0) tile[0][x] = sum;
  __syncthreads();

  if (threadIdx.x < 8) {

@@ -613,9 +621,11 @@ __global__ void ls_dropout_bias_bwd_kernel(
(the same block, with WARP_SIZE as the shuffle loop bound)

@@ -679,7 +689,8 @@ and @@ -724,7 +735,8 @@ __global__ void ls_dropout_act_bias_kernel(
(again the scale / thread-index / early-return / curand_init pattern)

@@ -885,9 +897,11 @@ __global__ void ls_dropout_act_bias_bwd_kernel(
  float sum = tile[threadIdx.y][threadIdx.x];
  __syncthreads();

  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);

  if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
  __syncthreads();

  if (threadIdx.y == 0) {
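All of the dropout hunks above share the same inverted-dropout arithmetic: each kept element is scaled by 1 / (1 - ratio) at training time so the expected activation is unchanged. A minimal PyTorch sketch of that arithmetic (a reference for the math only, not the fused kernel):

import torch

def inverted_dropout(x: torch.Tensor, ratio: float, training: bool = True):
    # Mirrors the kernels' scale = 1.f / (1.f - ratio); dropped lanes become 0,
    # kept lanes are scaled up so E[out] == E[x].
    if not training or ratio == 0.0:
        return x, torch.ones_like(x, dtype=torch.uint8)
    scale = 1.0 / (1.0 - ratio)
    mask = (torch.rand_like(x) > ratio).to(torch.uint8)
    return x * mask * scale, mask

out, mask = inverted_dropout(torch.randn(4, 8), ratio=0.1)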
colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu · View file @ 58580b50

#include <cooperative_groups.h>
#include "kernels.h"

namespace cg = cooperative_groups;

/**
colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h · View file @ 58580b50

@@ -13,23 +13,22 @@ const float REDUCE_FLOAT_INF_NEG = -100000000.f;
const float REDUCE_FLOAT_INF_POS = 100000000.f;
const unsigned int WARP_REDUCE_SIZE = 32;

template <typename T>
__forceinline__ __device__ T warpReduceSum(T val) {
  for (int mask = (WARP_REDUCE_SIZE >> 1); mask > 0; mask >>= 1)
    val += __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, WARP_REDUCE_SIZE);
  return val;
}

/* Calculate the sum of all elements in a block */
template <typename T>
__forceinline__ __device__ T blockReduceSum(T val) {
  static __shared__ T shared[32];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  val = warpReduceSum<T>(val);

  if (lane == 0) shared[wid] = val;
  __syncthreads();

  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)0.0f;

@@ -57,10 +56,10 @@ __inline__ __device__ void warpReduce<ReduceType::kMax, 1>(float *pval) {
template <>
__inline__ __device__ void warpReduce<ReduceType::kMax, 2>(float *pval) {
  float val0_tmp, val1_tmp;
#define WarpReduceMaxOneStep(a, b)                                 \
  val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval), a, b);     \
  val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
  *(pval) = max(val0_tmp, *(pval));                                \
  *(pval + 1) = max(val1_tmp, *(pval + 1));

  WarpReduceMaxOneStep(16, 32);

@@ -89,10 +88,10 @@ __inline__ __device__ void warpReduce<ReduceType::kSum, 1>(float *pval) {
template <>
__inline__ __device__ void warpReduce<ReduceType::kSum, 2>(float *pval) {
  float val0_tmp, val1_tmp;
#define WarpReduceSumOneStep(a, b)                                     \
  val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b);     \
  val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b);     \
  *(pval + 0) += val0_tmp;                                             \
  *(pval + 1) += val1_tmp

  WarpReduceSumOneStep(16, 32);

@@ -107,14 +106,14 @@ __inline__ __device__ void warpReduce<ReduceType::kSum, 2>(float *pval) {
template <>
__inline__ __device__ void warpReduce<ReduceType::kSum, 4>(float *pval) {
  float val0_tmp, val1_tmp, val2_tmp, val3_tmp;
#define WarpReduceSumOneStep(a, b)                                     \
  val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b);     \
  val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b);     \
  val2_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 2), a, b);     \
  val3_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 3), a, b);     \
  *(pval + 0) += val0_tmp;                                             \
  *(pval + 1) += val1_tmp;                                             \
  *(pval + 2) += val2_tmp;                                             \
  *(pval + 3) += val3_tmp

  WarpReduceSumOneStep(16, 32);
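The warpReduceSum/blockReduceSum code above is a butterfly (XOR-shuffle) tree reduction across the 32 lanes of a warp. A small NumPy sketch of the same access pattern, purely for intuition — it simulates __shfl_xor_sync over an array of lane values, not actual GPU execution:

import numpy as np

def warp_reduce_sum(lane_vals: np.ndarray, warp_size: int = 32) -> np.ndarray:
    # Each step adds the value held by the lane whose index differs in one bit,
    # mirroring: for (mask = warp_size >> 1; mask > 0; mask >>= 1)
    #              val += __shfl_xor_sync(FULL_MASK, val, mask);
    vals = lane_vals.astype(np.float64).copy()
    mask = warp_size >> 1
    while mask > 0:
        partner = np.arange(warp_size) ^ mask
        vals = vals + vals[partner]
        mask >>= 1
    return vals  # after log2(warp_size) steps, every lane holds the full sum

lanes = np.arange(32, dtype=np.float64)
assert np.allclose(warp_reduce_sum(lanes), lanes.sum())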
colossalai/kernel/cuda_native/csrc/kernels/include/context.h · View file @ 58580b50

@@ -9,7 +9,7 @@
#include "cuda_util.h"

class Context {
 public:
  Context() : _stream(nullptr) {
    CHECK_GPU_ERROR(cublasCreate(&_cublasHandle));
  }

@@ -30,7 +30,7 @@ class Context {
  cublasHandle_t get_cublashandle() { return _cublasHandle; }

 private:
  cudaStream_t _stream;
  cublasHandle_t _cublasHandle;
};
colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h · View file @ 58580b50

@@ -8,9 +8,8 @@
#include "cuda_util.h"

template <typename T>
class CrossEntropyLayer {
 public:
  CrossEntropyLayer(float epsilon, int padding_idx, int max_batch_tokens);
  virtual ~CrossEntropyLayer();

@@ -23,7 +22,7 @@ class CrossEntropyLayer {
  void set_cur_batch_shape(int batch_size, int seq_len, int vocab_size);

 private:
  void allocate_mem_buffer() {
    // allocate local gpu memory
    _loss_buffer = cuda_malloc<float>(_max_batch_tokens * 2);
colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h · View file @ 58580b50

@@ -20,8 +20,7 @@ void check_gpu_error(T result, char const *const func, const char *const file,
template <typename T>
void print_vec(const T *outv, std::string outn, int num_output_ele);

template <typename T>
T *cuda_malloc(size_t ele_num);

void cuda_free(void *pdata);

@@ -29,6 +28,6 @@ template <typename T>
void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf,
                   std::string file, int line, cudaStream_t stream);

#define CHECK_NAN_INF(ptr, size, stream)                            \
  check_nan_inf((ptr), (size), true, __FILE__, __LINE__, (stream)); \
  check_nan_inf((ptr), (size), false, __FILE__, __LINE__, (stream))
colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h · View file @ 58580b50

@@ -3,14 +3,12 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <string>

#include "kernels.h"

template <typename T>
class Dropout {
 public:
  struct Config {
    float ratio;
    bool training;

@@ -90,7 +88,7 @@ class Dropout {
  void SetTrainingMode(bool training) { _config.training = training; }

 private:
  uint8_t *_mask;
  Config _config;
};
colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h · View file @ 58580b50

@@ -13,16 +13,14 @@
#include "cublas_wrappers.h"
#include "kernels.h"

template <typename T>
class FeedForward {
 public:
  struct Config {
    int outputSize;
    int inputSize;
    std::array<int, 3> gemm_algos;
    Config(int outputs, int inputs)
        : outputSize(outputs),
          inputSize(inputs),
          gemm_algos(std::array<int, 3>({99, 99, 99})) {}
  };

@@ -63,6 +61,6 @@ class FeedForward {
    config_.inputSize = inputSize;
  }

 private:
  Config config_;
};
colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h · View file @ 58580b50

@@ -10,9 +10,8 @@
using namespace std;

template <typename T>
class Softmax {
 public:
  struct Config {
    size_t nhead;
    Config(size_t nhead) : nhead(nhead) {}

@@ -37,6 +36,6 @@ class Softmax {
  void reset_size(size_t nhead) { config_.nhead = nhead; }

 private:
  Config config_;
};
colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu · View file @ 58580b50

#include "block_reduce.h"
#include "kernels.h"
#include <cooperative_groups.h>

namespace cg = cooperative_groups;
colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu · View file @ 58580b50

#include <cooperative_groups.h>
#include <math.h>
#include <cub/block/block_load.cuh>

@@ -7,6 +6,8 @@
#include "block_reduce.h"
#include "kernels.h"

namespace cg = cooperative_groups;
const float EPSILON = 1e-8f;

@@ -119,7 +120,7 @@ __global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len,
      BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
                                 to_len);
    }
  }  // blockIdx.x
}

template <typename T, int block_dim, int ele_per_thread>

@@ -197,7 +198,7 @@ __global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len,
      BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
                                 to_len);
    }
  }  // blockIdx.x
}

/*

@@ -303,7 +304,8 @@ __global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) {
  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);

#pragma unroll
  for (int i = 0; i < ITERATIONS; ++i) {
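The ker_attn_softmax hunks above store one softmax row per block over the to_len dimension, with an attention mask applied. A minimal PyTorch reference of the masked attention softmax those kernels compute (the standard formulation; the additive-mask convention is an assumption for illustration):

import torch

def attn_softmax_ref(scores: torch.Tensor, attn_mask=None):
    # scores: [batch, nhead, from_len, to_len]; the mask is added before the
    # softmax (0 for kept positions, a large negative value for masked ones).
    if attn_mask is not None:
        scores = scores + attn_mask
    scores = scores - scores.max(dim=-1, keepdim=True).values  # numerical stability
    return torch.softmax(scores, dim=-1)

s = torch.randn(2, 4, 8, 8)
mask = torch.zeros(2, 1, 1, 8)
mask[..., 6:] = -1e9
print(attn_softmax_ref(s, mask).sum(-1))  # every row sums to 1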
colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp · View file @ 58580b50

@@ -2,13 +2,11 @@
 * https://github.com/NVIDIA/apex
 * with minor changes. */

#include <torch/extension.h>
#include <cassert>
#include <vector>
#include "compat.h"

namespace {

void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,

@@ -67,7 +65,7 @@ void check_args(at::Tensor input, at::IntArrayRef normalized_shape,
  check_args(input, normalized_shape, n1, n2);
  check_args(normalized_shape, gamma, beta);
}

}  // namespace

void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar,
                     at::Tensor *input, int n1, int n2,

@@ -75,16 +73,17 @@ void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar,
                     at::Tensor *beta, double epsilon);

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

std::vector<at::Tensor> layer_norm_affine(at::Tensor input,
                                          at::IntArrayRef normalized_shape,
                                          at::Tensor gamma, at::Tensor beta,
                                          double epsilon) {
  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);

@@ -110,10 +109,11 @@ void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean,
                              double epsilon, at::Tensor *grad_input,
                              at::Tensor *grad_gamma, at::Tensor *grad_beta);

std::vector<at::Tensor> layer_norm_gradient_affine(
    at::Tensor dout, at::Tensor mean, at::Tensor invvar, at::Tensor input,
    at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta,
    double epsilon) {
  CHECK_INPUT(dout);
  CHECK_INPUT(mean);
  CHECK_INPUT(invvar);
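For intuition, layer_norm_affine and layer_norm_gradient_affine above wrap standard affine layer normalization, exposing the per-row mean and inverse standard deviation (invvar) that the backward pass reuses. A plain-PyTorch reference of the forward statistics — the textbook formulation, not the CUDA implementation:

import torch

def layer_norm_affine_ref(x, normalized_shape, gamma, beta, eps=1e-5):
    # Normalize over the trailing `normalized_shape` dims, then scale and shift.
    dims = tuple(range(x.dim() - len(normalized_shape), x.dim()))
    mean = x.mean(dim=dims, keepdim=True)
    var = x.var(dim=dims, unbiased=False, keepdim=True)
    invvar = torch.rsqrt(var + eps)
    out = (x - mean) * invvar * gamma + beta
    return out, mean, invvar

x = torch.randn(4, 16)
out, mean, invvar = layer_norm_affine_ref(x, (16,), torch.ones(16), torch.zeros(16))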
colossalai/kernel/cuda_native/csrc/moe_cuda.cpp · View file @ 58580b50

@@ -15,24 +15,25 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
                                       torch::Tensor logits, torch::Tensor mask,
                                       torch::Tensor dest_idx);

std::vector<torch::Tensor> moe_combine_cuda_backward(
    int s, int e, int c, int h, torch::Tensor tokens_grad,
    torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
    torch::Tensor dest_idx);

torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask);

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

torch::Tensor moe_dispatch_forward(int s, int ec, int h,
                                   torch::Tensor batch_tokens,
                                   torch::Tensor mask, torch::Tensor dest_idx) {
  CHECK_INPUT(batch_tokens);
  CHECK_CUDA(mask);
  CHECK_CUDA(dest_idx);

@@ -44,6 +45,7 @@ torch::Tensor moe_dispatch_backward(int s, int ec, int h,
                                    torch::Tensor expert_grad,
                                    torch::Tensor mask, torch::Tensor dest_idx) {
  CHECK_INPUT(expert_grad);
  CHECK_CUDA(mask);
  CHECK_CUDA(dest_idx);

@@ -55,6 +57,7 @@ torch::Tensor moe_combine_forward(int s, int e, int c, int h,
                                  torch::Tensor expert_tokens,
                                  torch::Tensor logits, torch::Tensor mask,
                                  torch::Tensor dest_idx) {
  CHECK_INPUT(expert_tokens);
  CHECK_INPUT(logits);
  CHECK_CUDA(mask);

@@ -64,12 +67,11 @@ torch::Tensor moe_combine_forward(int s, int e, int c, int h,
                                  dest_idx);
}

std::vector<torch::Tensor> moe_combine_backward(
    int s, int e, int c, int h, torch::Tensor tokens_grad,
    torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
    torch::Tensor dest_idx) {
  CHECK_INPUT(tokens_grad);
  CHECK_INPUT(logits);
  CHECK_CUDA(mask);
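The helper cumsum_sub_one_in_dim0 declared above is, judging by its name, a cumulative sum of the routing mask along dim 0 shifted down by one, which yields the zero-based slot of each selected token within its expert's capacity. A plain-PyTorch equivalent — an assumption about the intended semantics, not the CUDA kernel itself:

import torch

def cumsum_sub_one_in_dim0(mask: torch.Tensor) -> torch.Tensor:
    # Running count of selected tokens per expert column, zero-based.
    return torch.cumsum(mask, dim=0) - 1

mask = torch.tensor([[1, 0], [1, 1], [0, 1]])
print(cumsum_sub_one_in_dim0(mask))
# tensor([[ 0, -1],
#         [ 1,  0],
#         [ 1,  1]])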
colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu
View file @
58580b50
#include "block_reduce.h"
#include <cub/cub.cuh>
#include <cuda.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <torch/extension.h>
#include <cub/cub.cuh>
#include "block_reduce.h"
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_dpch_one_fwd
(
T
*
src_row
,
T
*
dst_row
,
const
int
cols
)
{
__device__
void
moe_dpch_one_fwd
(
T
*
src_row
,
T
*
dst_row
,
const
int
cols
)
{
assert
(
cols
%
pack_size
==
0
);
assert
(
cols
%
pack_size
==
0
);
const
int
bpack_size
=
block_size
*
pack_size
;
const
int
bpack_size
=
block_size
*
pack_size
;
...
@@ -29,6 +28,7 @@ __device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
...
@@ -29,6 +28,7 @@ __device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_dpch_one_bwd
(
T
*
src_row
,
T
*
dst_row
,
const
int
cols
)
{
__device__
void
moe_dpch_one_bwd
(
T
*
src_row
,
T
*
dst_row
,
const
int
cols
)
{
assert
(
cols
%
pack_size
==
0
);
assert
(
cols
%
pack_size
==
0
);
const
int
bpack_size
=
block_size
*
pack_size
;
const
int
bpack_size
=
block_size
*
pack_size
;
...
@@ -51,6 +51,7 @@ __device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
...
@@ -51,6 +51,7 @@ __device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_dpch_two_fwd
(
T
*
src_row
,
T
*
dst_row1
,
T
*
dst_row2
,
__device__
void
moe_dpch_two_fwd
(
T
*
src_row
,
T
*
dst_row1
,
T
*
dst_row2
,
const
int
cols
)
{
const
int
cols
)
{
assert
(
cols
%
pack_size
==
0
);
assert
(
cols
%
pack_size
==
0
);
const
int
bpack_size
=
block_size
*
pack_size
;
const
int
bpack_size
=
block_size
*
pack_size
;
...
@@ -74,6 +75,7 @@ __device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
...
@@ -74,6 +75,7 @@ __device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_dpch_two_bwd
(
T
*
src_row
,
T
*
dst_row1
,
T
*
dst_row2
,
__device__
void
moe_dpch_two_bwd
(
T
*
src_row
,
T
*
dst_row1
,
T
*
dst_row2
,
const
int
cols
)
{
const
int
cols
)
{
assert
(
cols
%
pack_size
==
0
);
assert
(
cols
%
pack_size
==
0
);
const
int
bpack_size
=
block_size
*
pack_size
;
const
int
bpack_size
=
block_size
*
pack_size
;
...
@@ -103,6 +105,7 @@ __device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
...
@@ -103,6 +105,7 @@ __device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_cb_one_fwd
(
T
*
src_row
,
T
*
dst_row
,
const
T
weight
,
__device__
void
moe_cb_one_fwd
(
T
*
src_row
,
T
*
dst_row
,
const
T
weight
,
const
int
cols
)
{
const
int
cols
)
{
assert
(
cols
%
pack_size
==
0
);
assert
(
cols
%
pack_size
==
0
);
const
int
bpack_size
=
block_size
*
pack_size
;
const
int
bpack_size
=
block_size
*
pack_size
;
...
@@ -131,6 +134,7 @@ __device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
...
@@ -131,6 +134,7 @@ __device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_cb_one_bwd
(
T
*
src_row
,
T
*
dst_row
,
T
*
tks_row
,
__device__
void
moe_cb_one_bwd
(
T
*
src_row
,
T
*
dst_row
,
T
*
tks_row
,
T
*
weight_grad
,
const
T
weight
,
const
int
cols
)
{
T
*
weight_grad
,
const
T
weight
,
const
int
cols
)
{
assert
(
cols
%
pack_size
==
0
);
assert
(
cols
%
pack_size
==
0
);
const
int
bpack_size
=
block_size
*
pack_size
;
const
int
bpack_size
=
block_size
*
pack_size
;
...
@@ -160,13 +164,15 @@ __device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
...
@@ -160,13 +164,15 @@ __device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
blockReduce
<
ReduceType
::
kSum
,
1
>
(
&
thread_sum
);
blockReduce
<
ReduceType
::
kSum
,
1
>
(
&
thread_sum
);
if
(
threadIdx
.
x
==
0
)
*
weight_grad
=
static_cast
<
T
>
(
thread_sum
);
if
(
threadIdx
.
x
==
0
)
*
weight_grad
=
static_cast
<
T
>
(
thread_sum
);
}
}
template
<
typename
T
,
int
block_size
,
int
pack_size
>
template
<
typename
T
,
int
block_size
,
int
pack_size
>
__device__
void
moe_cb_two_fwd
(
T
*
src_row1
,
T
*
src_row2
,
T
*
dst_row
,
__device__
void
moe_cb_two_fwd
(
T
*
src_row1
,
T
*
src_row2
,
T
*
dst_row
,
const
T
weight1
,
const
T
weight2
,
const
T
weight1
,
const
T
weight2
,
const
int
cols
)
{
const
int
cols
)
{
assert
(
cols
%
pack_size
==
0
);
assert
(
cols
%
pack_size
==
0
);
const
int
bpack_size
=
block_size
*
pack_size
;
const
int
bpack_size
=
block_size
*
pack_size
;
...
@@ -198,6 +204,7 @@ __device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row,
...
@@ -198,6 +204,7 @@ __device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row,
T
*
tks_row1
,
T
*
tks_row2
,
T
*
weight_grad1
,
T
*
tks_row1
,
T
*
tks_row2
,
T
*
weight_grad1
,
T
*
weight_grad2
,
const
T
weight1
,
T
*
weight_grad2
,
const
T
weight1
,
const
T
weight2
,
const
int
cols
)
{
const
T
weight2
,
const
int
cols
)
{
assert
(
cols
%
pack_size
==
0
);
assert
(
cols
%
pack_size
==
0
);
const
int
bpack_size
=
block_size
*
pack_size
;
const
int
bpack_size
=
block_size
*
pack_size
;
...
@@ -244,6 +251,7 @@ template <typename T, int block_size, int pack_size>
...
@@ -244,6 +251,7 @@ template <typename T, int block_size, int pack_size>
__device__
void
moe_dpch_fwd_selector
(
T
*
src_row
,
T
*
dst_row1
,
T
*
dst_row2
,
__device__
void
moe_dpch_fwd_selector
(
T
*
src_row
,
T
*
dst_row1
,
T
*
dst_row2
,
const
int
cols
,
const
int
indicator1
,
const
int
cols
,
const
int
indicator1
,
const
int
indicator2
)
{
const
int
indicator2
)
{
if
(
indicator1
!=
0
&&
indicator2
!=
0
)
if
(
indicator1
!=
0
&&
indicator2
!=
0
)
moe_dpch_two_fwd
<
T
,
block_size
,
pack_size
>
(
src_row
,
dst_row1
,
dst_row2
,
moe_dpch_two_fwd
<
T
,
block_size
,
pack_size
>
(
src_row
,
dst_row1
,
dst_row2
,
cols
);
cols
);
...
@@ -259,6 +267,7 @@ template <typename T, int block_size, int pack_size>
...
@@ -259,6 +267,7 @@ template <typename T, int block_size, int pack_size>
__device__
void
moe_dpch_bwd_selector
(
T
*
src_row
,
T
*
dst_row1
,
T
*
dst_row2
,
__device__
void
moe_dpch_bwd_selector
(
T
*
src_row
,
T
*
dst_row1
,
T
*
dst_row2
,
const
int
cols
,
const
int
indicator1
,
const
int
cols
,
const
int
indicator1
,
const
int
indicator2
)
{
const
int
indicator2
)
{
if
(
indicator1
!=
0
&&
indicator2
!=
0
)
if
(
indicator1
!=
0
&&
indicator2
!=
0
)
moe_dpch_two_bwd
<
T
,
block_size
,
pack_size
>
(
src_row
,
dst_row1
,
dst_row2
,
moe_dpch_two_bwd
<
T
,
block_size
,
pack_size
>
(
src_row
,
dst_row1
,
dst_row2
,
cols
);
cols
);
...
@@ -274,6 +283,7 @@ template <typename T, int block_size, int pack_size>
...
@@ -274,6 +283,7 @@ template <typename T, int block_size, int pack_size>
__global__
void
moe_dpch_fwd_kernel
(
T
*
batch_tokens
,
T
*
expert_input
,
__global__
void
moe_dpch_fwd_kernel
(
T
*
batch_tokens
,
T
*
expert_input
,
int
*
mask1
,
int
*
mask2
,
int
*
dest1
,
int
*
mask1
,
int
*
mask2
,
int
*
dest1
,
int
*
dest2
,
const
int
h
)
{
int
*
dest2
,
const
int
h
)
{
int
row
=
blockIdx
.
x
;
int
row
=
blockIdx
.
x
;
int
indicator2
=
mask2
==
nullptr
?
0
:
mask2
[
row
];
int
indicator2
=
mask2
==
nullptr
?
0
:
mask2
[
row
];
moe_dpch_fwd_selector
<
T
,
block_size
,
pack_size
>
(
moe_dpch_fwd_selector
<
T
,
block_size
,
pack_size
>
(
...
@@ -285,6 +295,7 @@ template <typename T, int block_size, int pack_size>
...
@@ -285,6 +295,7 @@ template <typename T, int block_size, int pack_size>
__global__
void
moe_dpch_bwd_kernel
(
T
*
tokens_grad
,
T
*
expert_grad
,
int
*
mask1
,
__global__
void
moe_dpch_bwd_kernel
(
T
*
tokens_grad
,
T
*
expert_grad
,
int
*
mask1
,
int
*
mask2
,
int
*
dest1
,
int
*
dest2
,
int
*
mask2
,
int
*
dest1
,
int
*
dest2
,
const
int
h
)
{
const
int
h
)
{
int
row
=
blockIdx
.
x
;
int
row
=
blockIdx
.
x
;
int
indicator2
=
mask2
==
nullptr
?
0
:
mask2
[
row
];
int
indicator2
=
mask2
==
nullptr
?
0
:
mask2
[
row
];
moe_dpch_bwd_selector
<
T
,
block_size
,
pack_size
>
(
moe_dpch_bwd_selector
<
T
,
block_size
,
pack_size
>
(
...
@@ -299,6 +310,7 @@ __device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row,
...
@@ -299,6 +310,7 @@ __device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row,
const
int
cols
,
const
T
weight1
,
const
int
cols
,
const
T
weight1
,
const
T
weight2
,
const
int
indicator1
,
const
T
weight2
,
const
int
indicator1
,
const
int
indicator2
)
{
const
int
indicator2
)
{
if
(
indicator1
!=
0
&&
indicator2
!=
0
)
if
(
indicator1
!=
0
&&
indicator2
!=
0
)
moe_cb_two_fwd
<
T
,
block_size
,
pack_size
>
(
src_row1
,
src_row2
,
dst_row
,
moe_cb_two_fwd
<
T
,
block_size
,
pack_size
>
(
src_row1
,
src_row2
,
dst_row
,
weight1
,
weight2
,
cols
);
weight1
,
weight2
,
cols
);
...
@@ -316,6 +328,7 @@ __device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row,
...
@@ -316,6 +328,7 @@ __device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row,
T
*
wt_grad1
,
T
*
wt_grad2
,
const
T
weight1
,
T
*
wt_grad1
,
T
*
wt_grad2
,
const
T
weight1
,
const
T
weight2
,
const
int
indicator1
,
const
T
weight2
,
const
int
indicator1
,
const
int
indicator2
)
{
const
int
indicator2
)
{
if
(
indicator1
!=
0
&&
indicator2
!=
0
)
if
(
indicator1
!=
0
&&
indicator2
!=
0
)
moe_cb_two_bwd
<
T
,
block_size
,
pack_size
>
(
src_row1
,
src_row2
,
dst_row
,
moe_cb_two_bwd
<
T
,
block_size
,
pack_size
>
(
src_row1
,
src_row2
,
dst_row
,
tks_row1
,
tks_row2
,
wt_grad1
,
tks_row1
,
tks_row2
,
wt_grad1
,
...
@@ -335,6 +348,7 @@ __global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens,
...
@@ -335,6 +348,7 @@ __global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens,
T
*
logits
,
int
*
mask1
,
int
*
mask2
,
int
*
dest1
,
T
*
logits
,
int
*
mask1
,
int
*
mask2
,
int
*
dest1
,
int
*
dest2
,
const
int
e
,
const
int
c
,
int
*
dest2
,
const
int
e
,
const
int
c
,
const
int
h
)
{
const
int
h
)
{
int
row
=
blockIdx
.
x
,
eid1
=
dest1
[
row
]
/
c
,
eid2
=
dest2
[
row
]
/
c
;
int
row
=
blockIdx
.
x
,
eid1
=
dest1
[
row
]
/
c
,
eid2
=
dest2
[
row
]
/
c
;
int
indicator2
=
mask2
==
nullptr
?
0
:
mask2
[
row
];
int
indicator2
=
mask2
==
nullptr
?
0
:
mask2
[
row
];
T
*
row_log
=
logits
+
(
row
*
e
);
T
*
row_log
=
logits
+
(
row
*
e
);
...
@@ -349,6 +363,7 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
...
@@ -349,6 +363,7 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
T
*
logits
,
T
*
logits_grad
,
int
*
mask1
,
T
*
logits
,
T
*
logits_grad
,
int
*
mask1
,
int
*
mask2
,
int
*
dest1
,
int
*
dest2
,
int
*
mask2
,
int
*
dest1
,
int
*
dest2
,
const
int
e
,
const
int
c
,
const
int
h
)
{
const
int
e
,
const
int
c
,
const
int
h
)
{
int
row
=
blockIdx
.
x
,
eid1
=
dest1
[
row
]
/
c
,
eid2
=
dest2
[
row
]
/
c
;
int
row
=
blockIdx
.
x
,
eid1
=
dest1
[
row
]
/
c
,
eid2
=
dest2
[
row
]
/
c
;
int
indicator2
=
mask2
==
nullptr
?
0
:
mask2
[
row
];
int
indicator2
=
mask2
==
nullptr
?
0
:
mask2
[
row
];
T
*
row_log
=
logits
+
(
row
*
e
),
*
row_grad
=
logits_grad
+
(
row
*
e
);
T
*
row_log
=
logits
+
(
row
*
e
),
*
row_grad
=
logits_grad
+
(
row
*
e
);
...
@@ -364,6 +379,7 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
...
@@ -364,6 +379,7 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
template
<
int
block_size
,
int
pack_size
>
template
<
int
block_size
,
int
pack_size
>
__global__
void
cumsum_kernel
(
int
*
inputs
,
int
*
outputs
,
const
int
s
,
__global__
void
cumsum_kernel
(
int
*
inputs
,
int
*
outputs
,
const
int
s
,
const
int
e
)
{
const
int
e
)
{
assert
(
s
%
pack_size
==
0
);
assert
(
s
%
pack_size
==
0
);
constexpr
int
bpack_size
=
block_size
*
pack_size
;
constexpr
int
bpack_size
=
block_size
*
pack_size
;
int
tid
=
threadIdx
.
x
,
bid
=
blockIdx
.
x
,
tps
=
tid
*
pack_size
,
last_sum
=
-
1
;
int
tid
=
threadIdx
.
x
,
bid
=
blockIdx
.
x
,
tps
=
tid
*
pack_size
,
last_sum
=
-
1
;
...
@@ -410,7 +426,8 @@ __global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
...
@@ -410,7 +426,8 @@ __global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
}
}
__syncthreads
();
__syncthreads
();
if
(
tid
==
0
)
temp
[
0
]
=
temp
[
block_size
];
if
(
tid
==
0
)
temp
[
0
]
=
temp
[
block_size
];
__syncthreads
();
__syncthreads
();
if
(
idx
+
tps
<
s
)
{
if
(
idx
+
tps
<
s
)
{
...
@@ -436,6 +453,7 @@ template <typename T>
...
@@ -436,6 +453,7 @@ template <typename T>
void
moe_dpch_fwd_launch
(
T
*
batch_tokens
,
T
*
expert_input
,
int
*
mask1
,
void
moe_dpch_fwd_launch
(
T
*
batch_tokens
,
T
*
expert_input
,
int
*
mask1
,
int
*
mask2
,
int
*
dest1
,
int
*
dest2
,
const
int
s
,
int
*
mask2
,
int
*
dest1
,
int
*
dest2
,
const
int
s
,
const
int
h
)
{
const
int
h
)
{
if
(
h
<
256
)
if
(
h
<
256
)
moe_dpch_fwd_kernel
<
T
,
32
,
4
>
moe_dpch_fwd_kernel
<
T
,
32
,
4
>
<<<
s
,
32
>>>
(
batch_tokens
,
expert_input
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
<<<
s
,
32
>>>
(
batch_tokens
,
expert_input
,
mask1
,
mask2
,
dest1
,
dest2
,
h
);
...
@@ -456,6 +474,7 @@ void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1,
template <typename T>
void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2,
                         int *dest1, int *dest2, const int s, const int h) {
  if (h < 256)
    moe_dpch_bwd_kernel<T, 32, 4>
        <<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
...
@@ -477,6 +496,7 @@ template <typename T>
void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits,
                       int *mask1, int *mask2, int *dest1, int *dest2,
                       const int s, const int e, const int c, const int h) {
  if (h < 256)
    moe_cb_fwd_kernel<T, 32, 4><<<s, 32>>>(expert_tokens, combine_tokens,
                                           logits, mask1, mask2, dest1, dest2,
...
@@ -504,11 +524,12 @@ void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
                       T *logits_grad, int *mask1, int *mask2, int *dest1,
                       int *dest2, const int s, const int e, const int c,
                       const int h) {
  if (h < 256)
    moe_cb_bwd_kernel<T, 32, 4><<<s, 32>>>(tokens_grad, expert_grad, tks,
                                           logits, logits_grad, mask1, mask2,
                                           dest1, dest2, e, c, h);
  else  // if (h < 512)
    moe_cb_bwd_kernel<T, 64, 4><<<s, 64>>>(tokens_grad, expert_grad, tks,
                                           logits, logits_grad, mask1, mask2,
                                           dest1, dest2, e, c, h);
...
@@ -523,6 +544,7 @@ void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
}

void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {
  if (s <= 256)
    cumsum_kernel<256, 1><<<e, 256>>>(inputs, outputs, s, e);
  else if (s <= 512)
...
@@ -537,26 +559,27 @@ void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {
// API FUNCTIONS --------------------------------
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...)                         \
  switch (TYPE) {                                                        \
    case at::ScalarType::Float: {                                        \
      using scalar_t = float;                                            \
      __VA_ARGS__;                                                       \
      break;                                                             \
    }                                                                    \
    case at::ScalarType::Half: {                                         \
      using scalar_t = at::Half;                                         \
      __VA_ARGS__;                                                       \
      break;                                                             \
    }                                                                    \
    default:                                                             \
      AT_ERROR(#NAME, " not implemented yet for specific data type.");  \
  }
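DISPATCH_FLOAT_AND_HALF expands its variadic arguments once per supported dtype, with scalar_t bound to the matching C++ type in each branch. A standalone sketch of the same switch-on-dtype pattern, using a plain enum and a stand-in half type instead of at::ScalarType and AT_ERROR so it compiles without PyTorch headers (all names here are illustrative, not the repo's):

#include <cstdio>
#include <stdexcept>

enum class DType { Float32, Float16 };
struct half16 { unsigned short bits; };  // stand-in for at::Half

#define MY_DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...)           \
  switch (TYPE) {                                              \
    case DType::Float32: {                                     \
      using scalar_t = float;                                  \
      __VA_ARGS__;                                             \
      break;                                                   \
    }                                                          \
    case DType::Float16: {                                     \
      using scalar_t = half16;                                 \
      __VA_ARGS__;                                             \
      break;                                                   \
    }                                                          \
    default:                                                   \
      throw std::runtime_error(#NAME " unsupported dtype");    \
  }

template <typename scalar_t>
void report() { std::printf("element size = %zu\n", sizeof(scalar_t)); }

int main() {
  DType t = DType::Float16;
  // The macro instantiates report<scalar_t>() for whichever branch matches.
  MY_DISPATCH_FLOAT_AND_HALF(t, report, report<scalar_t>());
  return 0;
}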
torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
                                        torch::Tensor batch_tokens,
                                        torch::Tensor mask,
                                        torch::Tensor dest_idx) {
  assert(h % 16 == 0);
  auto res = torch::zeros(
      {ec, h},
...
@@ -578,6 +601,7 @@ torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
                                         torch::Tensor expert_grad,
                                         torch::Tensor mask,
                                         torch::Tensor dest_idx) {
  assert(h % 16 == 0);
  auto res = torch::zeros(
      {s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device()));
...
@@ -598,6 +622,7 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
                                       torch::Tensor expert_tokens,
                                       torch::Tensor logits, torch::Tensor mask,
                                       torch::Tensor dest_idx) {
  assert(h % 16 == 0);
  assert(expert_tokens.dtype() == logits.dtype());
...
@@ -618,10 +643,11 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
  return res;
}

std::vector<torch::Tensor> moe_combine_cuda_backward(
    int s, int e, int c, int h, torch::Tensor tokens_grad,
    torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
    torch::Tensor dest_idx) {
  assert(h % 16 == 0);
  assert(tokens_grad.dtype() == expert_tokens.dtype());
  assert(expert_tokens.dtype() == logits.dtype());
...
@@ -647,6 +673,7 @@ std::vector<torch::Tensor> moe_combine_cuda_backward(
}

torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) {
  assert(mask.dim() == 2);
  assert(mask.dtype() == torch::kInt32);
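Judging by its name and the (s, e) int mask it receives, cumsum_sub_one_in_dim0 appears to return an inclusive cumulative sum over dim 0 with 1 subtracted, i.e. each selected token's 0-based slot within its expert column. A small CPU reference of that interpretation (an assumption about intent, not the kernel's code), which can serve as a check against the CUDA path:

#include <cstdio>
#include <vector>

// Inclusive cumsum over dim 0 of an s-by-e row-major int matrix, minus one.
std::vector<int> cumsum_sub_one_dim0_ref(const std::vector<int> &mask,
                                         int s, int e) {
  std::vector<int> out(s * e);
  for (int col = 0; col < e; ++col) {
    int running = 0;
    for (int row = 0; row < s; ++row) {
      running += mask[row * e + col];
      // Slot index of this token within expert `col`; stale where mask is 0.
      out[row * e + col] = running - 1;
    }
  }
  return out;
}

int main() {
  // 4 tokens, 2 experts; mask[row][col] == 1 if token `row` goes to expert `col`.
  std::vector<int> mask = {1, 0, 0, 1, 1, 0, 0, 1};
  auto out = cumsum_sub_one_dim0_ref(mask, 4, 2);
  for (int r = 0; r < 4; ++r) std::printf("%d %d\n", out[r * 2], out[r * 2 + 1]);
  return 0;
}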
...
colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu
View file @ 58580b50
...
@@ -16,8 +16,7 @@
#define BLOCK_SIZE 512
#define ILP 4

template <typename T>
__device__ __forceinline__ bool is_aligned(T *p) {
  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}
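is_aligned is the usual guard before reinterpreting a pointer as a wider vector type: the ILP-element load_store fast path is only legal when the pointer is aligned to ILP * sizeof(T). A small self-contained sketch of that pattern; the kernel below is illustrative, not this file's:

#include <cstdint>
#include <cuda_runtime.h>

#define ILP 4

template <typename T>
__device__ __forceinline__ bool is_aligned(T *p) {
  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}

// Copy n floats; use one float4 transaction per thread when both pointers
// allow it, otherwise fall back to scalar grid-stride copies.
__global__ void copy_maybe_vectorized(float *dst, const float *src, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (is_aligned(dst) && is_aligned(src) && n % ILP == 0) {
    for (int j = i; j * ILP < n; j += gridDim.x * blockDim.x)
      reinterpret_cast<float4 *>(dst)[j] =
          reinterpret_cast<const float4 *>(src)[j];
  } else {
    for (int j = i; j < n; j += gridDim.x * blockDim.x) dst[j] = src[j];
  }
}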
...
@@ -29,12 +28,11 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
  ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
}

template <typename x_t>
struct L2NormFunctor {
  __device__ __forceinline__ void operator()(
      int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
      float *output, float *output_per_tensor, bool per_tensor,
      int max_chunks_per_tensor) {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;
...
@@ -50,8 +48,8 @@ struct L2NormFunctor {
    __shared__ float s_vals[512];

    float vals[ILP];  // = {0}; // this probably works too but I want to be
                      // sure...
    x_t r_x[ILP];
    for (int i = 0; i < ILP; i++) {
      vals[i] = 0.f;
...
@@ -86,14 +84,15 @@ struct L2NormFunctor {
    }

    float val = 0.f;
    for (int i = 0; i < ILP; i++) val += vals[i];

    float final = reduce_block_into_lanes(s_vals, val);

    if (threadIdx.x == 0) {
      if (!isfinite(final))
        *noop_gmem =
            1;  // Blindly fire off a write. These will race but that's ok.
      output[blockIdx.x] += final;
      if (per_tensor)
        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) *
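reduce_block_into_lanes comes from this repo's block_reduce helpers and folds each thread's partial sum across the block. For intuition only, a minimal shared-memory block sum in the same spirit (not the helper's actual implementation, which leaves results in the final warp's lanes):

#include <cuda_runtime.h>

// Tree reduction in shared memory: after the loop, s_vals[0] holds the block
// sum of each thread's partial value `val`. Assumes blockDim.x is a power of 2.
__device__ float block_sum(float *s_vals, float val) {
  int tid = threadIdx.x;
  s_vals[tid] = val;
  __syncthreads();
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (tid < stride) s_vals[tid] += s_vals[tid + stride];
    __syncthreads();
  }
  return s_vals[0];
}

__global__ void sum_kernel(const float *in, float *out, int n) {
  __shared__ float s_vals[512];
  float val = 0.f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) val += in[i];
  float total = block_sum(s_vals, val);
  if (threadIdx.x == 0) *out = total;
}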
...
@@ -105,12 +104,11 @@ struct L2NormFunctor {
// Probably better to template, but since we are not likely to support other
// norm
template <typename x_t>
struct MaxNormFunctor {
  __device__ __forceinline__ void operator()(
      int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
      float *output, float *output_per_tensor, bool per_tensor,
      int max_chunks_per_tensor) {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;
...
@@ -126,8 +124,8 @@ struct MaxNormFunctor {
    __shared__ float s_vals[512];

    float vals[ILP];  // = {0}; // this probably works too but I want to be
                      // sure...
    x_t r_x[ILP];
    for (int i = 0; i < ILP; i++) {
      vals[i] = 0.f;
...
@@ -162,14 +160,15 @@ struct MaxNormFunctor {
    }

    float val = 0.f;
    for (int i = 0; i < ILP; i++) val = fmaxf(fabsf(val), fabsf(vals[i]));

    float final = reduce_block_into_lanes_max_op(s_vals, val);

    if (threadIdx.x == 0) {
      if (!isfinite(final))
        *noop_gmem =
            1;  // Blindly fire off a write. These will race but that's ok.
      output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final));
      if (per_tensor)
        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) *
...
@@ -186,11 +185,13 @@ __global__ void cleanup(float *output, float *output_per_tensor, float *ret,
  if (blockIdx.x == 0) {
    float val = 0;
    if (threadIdx.x < 320) val = output[threadIdx.x];

    float final = reduce_block_into_lanes(vals, val);

    if (threadIdx.x == 0) *ret = sqrt(final);
  }

  if (per_tensor) {
...
@@ -203,7 +204,8 @@ __global__ void cleanup(float *output, float *output_per_tensor, float *ret,
    float final = reduce_block_into_lanes(vals, val);

    if (threadIdx.x == 0) ret_per_tensor[blockIdx.x] = sqrt(final);
  }
}
...
@@ -215,14 +217,17 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
  if (blockIdx.x == 0) {
    float val = 0;
    if (threadIdx.x < 320) val = output[threadIdx.x];

    if (norm_type == 0) {
      float final = reduce_block_into_lanes_max_op(vals, val);
      if (threadIdx.x == 0) *ret = alpha * (*ret) + beta * final;
    } else {
      float final = reduce_block_into_lanes(vals, val);
      if (threadIdx.x == 0) *ret = sqrt(alpha * (*ret) * (*ret) + beta * final);
    }
  }
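cleanup_v2 folds the freshly reduced partial result into the running value *ret: for norm_type == 0 it is a linear alpha/beta blend, and for the L2 path the blend happens on squared norms before the square root. A tiny host-side check of those two update formulas (the values below are made up):

#include <cmath>
#include <cstdio>

int main() {
  float alpha = 1.0f, beta = 1.0f;

  // L2 path: keep sqrt(alpha * ret^2 + beta * new_sum_of_squares).
  float ret_l2 = 3.0f;       // previous norm, i.e. sqrt(9)
  float partial_sq = 16.0f;  // newly reduced sum of squares
  ret_l2 = std::sqrt(alpha * ret_l2 * ret_l2 + beta * partial_sq);
  std::printf("combined L2 norm: %f\n", ret_l2);  // sqrt(9 + 16) = 5

  // norm_type == 0 path: a linear blend of the running value and the new
  // block result, as in *ret = alpha * (*ret) + beta * final.
  float ret_max = 2.5f, partial_max = 4.0f;
  ret_max = alpha * ret_max + beta * partial_max;
  std::printf("combined value: %f\n", ret_max);
  return 0;
}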
...
@@ -255,10 +260,10 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
  }
}

std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
    int chunk_size, at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::optional<bool> per_tensor_python) {
  bool per_tensor =
      per_tensor_python.has_value() ? per_tensor_python.value() : false;
...
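multi_tensor_l2norm_cuda follows a two-pass shape: the functor accumulates per-chunk sums of squares into an output buffer, and a cleanup kernel finishes with a square root. A CPU reference of that decomposition (chunk size and data are arbitrary), handy for validating the fused path:

#include <cmath>
#include <cstdio>
#include <vector>

// Pass 1: sum of squares per chunk; pass 2: total and sqrt, mirroring the
// functor + cleanup split above.
float chunked_l2_norm_ref(const std::vector<float> &x, size_t chunk_size) {
  std::vector<float> partial;
  for (size_t start = 0; start < x.size(); start += chunk_size) {
    float s = 0.f;
    for (size_t i = start; i < x.size() && i < start + chunk_size; ++i)
      s += x[i] * x[i];
    partial.push_back(s);
  }
  float total = 0.f;
  for (float p : partial) total += p;
  return std::sqrt(total);
}

int main() {
  std::vector<float> x = {3.f, 4.f, 12.f};  // L2 norm is 13
  std::printf("l2 norm = %f\n", chunked_l2_norm_ref(x, 2));
  return 0;
}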