OpenDAS / ColossalAI / Commits / 58580b50

Unverified commit 58580b50, authored May 17, 2022 by ver217, committed by GitHub on May 17, 2022.

Revert "[NFC] Hotfix/format (#984)" (#986)

This reverts commit 0772828f.

parent 0772828f
Showing 15 changed files with 542 additions and 478 deletions (+542 −478). The reverted changes are formatting-only: line wrapping, whitespace, comment reflow, and include ordering.

colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu  (+25 −25)
colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu  (+7 −7)
colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu  (+241 −169)
colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp  (+96 −132)
colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h  (+12 −19)
colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu  (+48 −33)
colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp  (+25 −20)
colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu  (+38 −28)
colossalai/kernel/cuda_native/layer_norm.py  (+4 −3)
colossalai/kernel/cuda_native/scaled_softmax.py  (+23 −13)
colossalai/kernel/jit/bias_gelu.py  (+4 −8)
colossalai/nn/layer/parallel_2d/layers.py  (+7 −7)
colossalai/nn/layer/parallel_2p5d/layers.py  (+8 −8)
colossalai/nn/layer/parallel_3d/layers.py  (+3 −3)
colossalai/nn/layer/utils/common.py  (+1 −3)
colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu
View file @ 58580b50

@@ -15,8 +15,7 @@
#define BLOCK_SIZE 512
#define ILP 4

template <typename T> __device__ __forceinline__ bool is_aligned(T *p) {
  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}

@@ -29,25 +28,24 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
}

typedef enum {
  MOMENT_MODE_0 = 0,  // L2 regularization mode
  MOMENT_MODE_1 = 1   // Decoupled weight decay mode
} adamMode_t;

std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
    int chunk_size, at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::optional<bool> per_tensor_python);

using MATH_T = float;

template <typename T> struct LAMBStage1Functor {
  __device__ __forceinline__ void operator()(
      int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
      const float beta1, const float beta2, const float beta3,
      const float beta1_correction, const float beta2_correction,
      const float epsilon, adamMode_t mode, const float decay,
      const float *global_grad_norm, const float max_global_grad_norm) {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

@@ -91,7 +89,8 @@ struct LAMBStage1Functor {
         i_start += blockDim.x) {
      // load
      load_store(l_g, g, 0, i_start);
      if (decay != 0) load_store(l_p, p, 0, i_start);
      load_store(l_m, m, 0, i_start);
      load_store(l_v, v, 0, i_start);
      // unpack

@@ -205,12 +204,12 @@ struct LAMBStage1Functor {
// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
// It computes new parameter value.
template <typename T> struct LAMBStage2Functor {
  __device__ __forceinline__ void operator()(
      int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl,
      const float *per_tensor_param_norm, const float *per_tensor_update_norm,
      const float learning_rate, const float decay, bool use_nvlamb) {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

@@ -311,7 +310,8 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
  // Handle grad averaging mode
  float beta3 = 1.0f;
  if (grad_averaging == 1) beta3 = 1 - beta1;

  std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(),
                                                 tensor_lists.begin() + 1);

@@ -330,7 +330,7 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
      tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
      multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                            LAMBStage1Functor<scalar_t_0>(), beta1, beta2,
                            beta3,  // 1-beta1 or 1 depends on averaging mode
                            bias_correction1, bias_correction2, epsilon,
                            (adamMode_t)mode, weight_decay,
                            global_grad_norm.DATA_PTR<float>(), max_grad_norm);)
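For readers unfamiliar with LAMB, the two functors above split the usual two-stage formulation: stage 1 produces an Adam-style update per element, and stage 2 rescales that update by a per-tensor trust ratio ||p|| / ||update|| before applying it. The snippet below is a minimal single-tensor PyTorch sketch of that math, not the fused kernel; the hyperparameter names and the decoupled-weight-decay choice are my own illustrative assumptions.

import torch

def lamb_step(p, g, m, v, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, wd=0.01, step=1):
    # Stage 1: Adam-style moments and a raw update (decoupled weight decay variant).
    m.mul_(beta1).add_(g, alpha=1 - beta1)
    v.mul_(beta2).addcmul_(g, g, value=1 - beta2)
    m_hat = m / (1 - beta1 ** step)
    v_hat = v / (1 - beta2 ** step)
    update = m_hat / (v_hat.sqrt() + eps) + wd * p

    # Stage 2: scale the update by the per-tensor trust ratio ||p|| / ||update||.
    p_norm = p.norm()
    u_norm = update.norm()
    trust_ratio = torch.where((p_norm > 0) & (u_norm > 0), p_norm / u_norm, torch.ones_like(p_norm))
    p.add_(update, alpha=-(lr * trust_ratio.item()))
    return p

p, g = torch.randn(1024), torch.randn(1024)
m, v = torch.zeros(1024), torch.zeros(1024)
lamb_step(p, g, m, v)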
colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu
View file @ 58580b50

@@ -15,8 +15,7 @@
#define BLOCK_SIZE 512
#define ILP 4

template <typename T> __device__ __forceinline__ bool is_aligned(T *p) {
  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}

@@ -28,8 +27,7 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
  ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
}

template <typename in_t, typename out_t> struct ScaleFunctor {
  __device__ __forceinline__ void operator()(int chunk_size,
                                             volatile int *noop_gmem,
                                             TensorListMetadata<2> &tl,

@@ -78,7 +76,8 @@ struct ScaleFunctor {
        for (int ii = 0; ii < ILP; ii++) {
          r_in[ii] = 0;
          int i = i_start + threadIdx.x + ii * blockDim.x;
          if (i < n && i < chunk_size) r_in[ii] = in[i];
        }
        // note for clarification to future michael:
        // From a pure memory dependency perspective, there's likely no point

@@ -94,13 +93,14 @@ struct ScaleFunctor {
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          int i = i_start + threadIdx.x + ii * blockDim.x;
          if (i < n && i < chunk_size) out[i] = r_out[ii];
        }
      }
    }
    if (!finite)
      *noop_gmem =
          1;  // Blindly fire off a write. These will race but that's ok.
  }
};
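Functionally, ScaleFunctor multiplies every listed tensor by a scalar and raises the shared noop flag if any scaled value is non-finite, so the caller can skip the optimizer step after an overflow. A rough PyTorch equivalent, with invented names and none of the chunking, is:

import torch

def multi_tensor_scale_reference(outs, ins, scale):
    """Scale each input tensor into its output; return 1 if any result is non-finite."""
    noop_flag = 0
    for out, inp in zip(outs, ins):
        out.copy_(inp * scale)
        if not torch.isfinite(out).all():
            noop_flag = 1  # mirrors the racy "*noop_gmem = 1" write in the kernel
    return noop_flag

grads = [torch.randn(8), torch.tensor([float('inf'), 1.0])]
outs = [torch.empty_like(g) for g in grads]
print(multi_tensor_scale_reference(outs, grads, 0.5))  # -> 1 because of the inf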
colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu
View file @ 58580b50

// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_sgd_kernel.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include "multi_tensor_apply.cuh"
#include "compat.h"
#include <assert.h>
#include <cuda_runtime.h>

#define BLOCK_SIZE 512
#define ILP 4

@@ -29,53 +28,69 @@
 * wd_after_momentum : apply weight decay _after_ momentum instead of before
 **/
template <int N, typename T_grad, typename T_weight>
struct SGDFunctor {
  __device__ __forceinline__ void operator()(
      int chunk_size, volatile int *noop_gmem, TensorListMetadata<N> &tl,
      float wd, float momentum, float dampening, float lr, bool nesterov,
      bool first_run, bool wd_after_momentum, float scale) {
    // Early exit if we don't need to do anything
    if (*noop_gmem) return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    T_grad *grad_in = (T_grad *)tl.addresses[0][tensor_loc];
    grad_in += chunk_idx * chunk_size;

    T_weight *weight_in = (T_weight *)tl.addresses[1][tensor_loc];
    weight_in += chunk_idx * chunk_size;

    T_weight *mom_in = (T_weight *)tl.addresses[2][tensor_loc];
    mom_in += chunk_idx * chunk_size;

    at::Half *model_weights_out = nullptr;
    if (N == 4) {
      model_weights_out = (at::Half *)tl.addresses[3][tensor_loc];
      model_weights_out += chunk_idx * chunk_size;
    }

    n -= chunk_idx * chunk_size;

    // Non-divergent exit condition for the __syncthreads
    float incoming_grads[ILP];
    float incoming_weights[ILP];
    float incoming_moms[ILP];
    for (int i_start = 0; i_start < n && i_start < chunk_size;
         i_start += blockDim.x * ILP) {
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        incoming_grads[ii] = 0;
        incoming_weights[ii] = 0;
        incoming_moms[ii] = 0;
        int i = i_start + threadIdx.x + ii * blockDim.x;
        if (i < n && i < chunk_size) {
          incoming_grads[ii] = static_cast<float>(grad_in[i]) * scale;
          incoming_weights[ii] = static_cast<float>(weight_in[i]);
          incoming_moms[ii] = static_cast<float>(mom_in[i]);
        }
      }

      // note for clarification to future michael:
      // From a pure memory dependency perspective, there's likely no point unrolling

@@ -83,128 +98,185 @@ struct SGDFunctor {
      // Put another way, the STGs are dependent on the LDGs, but not on each other.
      // There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        int i = i_start + threadIdx.x + ii * blockDim.x;
        if (i < n && i < chunk_size) {
          // apply weight decay before momentum if necessary
          if (wd != 0.f && !wd_after_momentum)
            incoming_grads[ii] += wd * incoming_weights[ii];

          if (momentum != 0.f) {
            if (!first_run)
              incoming_moms[ii] = incoming_moms[ii] * momentum +
                                  (1.f - dampening) * incoming_grads[ii];
            else  // initialize momentums to current incoming grads
              incoming_moms[ii] = incoming_grads[ii];

            if (nesterov)
              incoming_grads[ii] += momentum * incoming_moms[ii];
            else
              incoming_grads[ii] = incoming_moms[ii];
          }

          // Apply WD after momentum if desired
          if (wd != 0.f && wd_after_momentum)
            incoming_grads[ii] += wd * incoming_weights[ii];

          // adjust the weight and write out
          weight_in[i] += (-lr * incoming_grads[ii]);

          // if necessary, write out an fp16 copy of the weights
          if (N == 4)
            model_weights_out[i] = static_cast<at::Half>(weight_in[i]);

          // also write out the new momentum
          if (momentum != 0.f) mom_in[i] = incoming_moms[ii];
        }
      }
    }
  }
};

void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
                           std::vector<std::vector<at::Tensor>> tensor_lists,
                           float wd, float momentum, float dampening, float lr,
                           bool nesterov, bool first_run,
                           bool wd_after_momentum, float scale) {
  auto num_tensors = tensor_lists.size();
  auto grad_type = tensor_lists[0][0].scalar_type();
  auto weight_type = tensor_lists[1][0].scalar_type();

  if (num_tensors == 4)
    for (int i = 0; i < tensor_lists[3].size(); i++)
      TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
                  "Additional output tensors should always be fp16.");

  TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(),
              "expected noop flag to be on the same device as tensors");

  // We have 3 possibilities to handle here, in terms of
  // grad_type, param_type, momentum_type, requires_fp16_copy
  // 1. fp16, fp16, fp16, No
  // 2. fp32, fp32, fp32, No
  // 3. fp16, fp32, fp32, Yes
  // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
  // It's easier to hardcode these possibilities than to use
  // switches etc. to handle the cross-product of cases where
  // we don't want the majority of them.

  // Case 1. fp16, fp16, fp16, No
  if (grad_type == at::ScalarType::Half &&
      weight_type == at::ScalarType::Half && num_tensors == 3) {
    multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                          SGDFunctor<3, at::Half, at::Half>(), wd, momentum,
                          dampening, lr, nesterov, first_run, wd_after_momentum,
                          scale);
  }
  // Case 2. fp16, fp32, fp32, No
  // else if (grad_type == at::ScalarType::Half &&
  //          weight_type == at::ScalarType::Float &&
  //          num_tensors == 3) {
  //   multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
  //                         SGDFunctor<3, at::Half, float>(), wd, momentum,
  //                         dampening, lr, nesterov, first_run,
  //                         wd_after_momentum);
  // }
  // Case 2. fp32, fp32, fp32, No
  else if (grad_type == at::ScalarType::Float &&
           weight_type == at::ScalarType::Float && num_tensors == 3) {
    multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                          SGDFunctor<3, float, float>(), wd, momentum,
                          dampening, lr, nesterov, first_run, wd_after_momentum,
                          scale);
  }
  // Case 3. fp16, fp32, fp32, Yes
  else if (grad_type == at::ScalarType::Half &&
           weight_type == at::ScalarType::Float && num_tensors == 4) {
    multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                          SGDFunctor<4, at::Half, float>(), wd, momentum,
                          dampening, lr, nesterov, first_run, wd_after_momentum,
                          scale);
  }
  // Case 4. fp32, fp32, fp32, Yes
  else if (grad_type == at::ScalarType::Float &&
           weight_type == at::ScalarType::Float && num_tensors == 4) {
    multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                          SGDFunctor<4, float, float>(), wd, momentum,
                          dampening, lr, nesterov, first_run, wd_after_momentum,
                          scale);
  } else {
    AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
             "gradient: ", grad_type, ", weight: ", weight_type,
             ", num_lists: ", num_tensors);
  }

  AT_CUDA_CHECK(cudaGetLastError());
}
\ No newline at end of file
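The control flow in SGDFunctor is easier to follow outside CUDA: per element it applies weight decay either before or after the momentum update, keeps a dampened momentum buffer (initialized to the first gradient on first_run), optionally uses the Nesterov form, and finally writes the new weight plus an optional fp16 copy. A scalar PyTorch sketch of the same update rule (the function and variable names are mine, not the kernel's) could look like:

import torch

def sgd_element(w, g, mom, *, wd, momentum, dampening, lr,
                nesterov, first_run, wd_after_momentum, scale):
    g = g * scale
    if wd != 0.0 and not wd_after_momentum:   # weight decay before momentum
        g = g + wd * w
    if momentum != 0.0:
        mom = g if first_run else mom * momentum + (1.0 - dampening) * g
        g = g + momentum * mom if nesterov else mom
    if wd != 0.0 and wd_after_momentum:       # or weight decay after momentum
        g = g + wd * w
    w = w - lr * g                            # adjust the weight
    return w, mom

w, g, mom = torch.tensor(1.0), torch.tensor(0.2), torch.tensor(0.0)
w, mom = sgd_element(w, g, mom, wd=1e-4, momentum=0.9, dampening=0.0, lr=0.1,
                     nesterov=False, first_run=True, wd_after_momentum=False, scale=1.0)
print(w, mom)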
colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp
View file @ 58580b50

This diff is collapsed.
colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h
View file @ 58580b50

@@ -19,25 +19,21 @@
template <typename T>
class MultiHeadAttention {
 public:
  MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len,
                     int hidden_size, int num_heads, float attn_dropout_ratio,
                     float hidden_output_dropout_ratio,
                     bool pre_or_postLayerNorm);

  virtual ~MultiHeadAttention();

  void Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr);

  void Backward(const T *grad_output_ptr, const T *input_ptr,
                const T *output_ptr, const T *input_mask_ptr,
                T *grad_input_ptr);

  void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr,
                     T *output_ptr, T *buffer);

  void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr,
                     const T *output_ptr, const T *grad_output_ptr,
                     T *grad_input_attn_layer_bwptr, T *buffer);

  void set_cur_batch_shape(int batch_size, int seq_len) {
    _batch_size = batch_size;

@@ -87,17 +83,14 @@ class MultiHeadAttention {
    }

    _qkv_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size * 3);
    _soft_out_ptr =
        cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
    _ctx_bufB_ptr =
        cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
    _attn_o_inp_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size);

    // buffer size needed by attn bw
    size_t smem_size =
        4 * _max_batch_tokens * _hidden_size / pg_size +
        std::max(3 * _max_batch_tokens * _hidden_size / pg_size,
                 _max_batch_tokens * _heads / pg_size * _max_seq_len);

    if (!_shared_mem_ptr) {
      cuda_free(_shared_mem_ptr);
colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu
View file @ 58580b50

@@ -2,13 +2,12 @@
 * with minor changes. */

#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "scaled_masked_softmax.h"
#include "type_shim.h"

@@ -16,15 +15,17 @@ namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches,
                             int attn_heads) {
  return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}

torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
                       float scale_factor) {
  // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
  const int batches = input.size(0);
  const int pad_batches = mask.size(0);
  const int attn_heads = input.size(1);

@@ -37,10 +38,10 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
  TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
  TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);

  // Output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results = torch::empty(
      {batches, attn_heads, query_seq_len, key_seq_len}, act_options);

  // Softmax Intermediate Result Ptr
  void* input_ptr = static_cast<void*>(input.data_ptr());

@@ -48,23 +49,31 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  DISPATCH_HALF_AND_BFLOAT(
      input.scalar_type(), "dispatch_scaled_masked_softmax_forward",
      dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr),
          reinterpret_cast<const uint8_t*>(mask_ptr), scale_factor,
          query_seq_len, key_seq_len, batches, attn_heads, pad_batches););

  return softmax_results;
}

torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
                       torch::Tensor const& softmax_results_,
                       float scale_factor) {
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
  const int batches = output_grads.size(0);
  const int attn_heads = output_grads.size(1);
  const int query_seq_len = output_grads.size(2);

@@ -72,18 +81,24 @@ torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  //Softmax Grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(), "dispatch_scaled_masked_softmax_backward",
      dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor, query_seq_len, key_seq_len, batches, attn_heads););

  //backward pass is completely in-place
  return output_grads;
}
}  // namespace scaled_masked_softmax
}  // namespace fused_softmax
}  // namespace multihead_attn
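What fwd_cuda and bwd_cuda fuse here is an ordinary scale-then-masked-softmax over the last dimension of a [batches, attn_heads, seq_len, seq_len] score tensor. An unfused PyTorch reference (my own sketch, assuming a boolean mask where True marks positions to suppress, the way Megatron-style kernels use the uint8 mask) is roughly:

import torch

def scaled_masked_softmax_reference(scores, mask, scale):
    # scores: [b, np, sq, sk]; mask: [b or 1, 1, sq, sk] with True = masked out
    scores = scores * scale
    scores = scores.masked_fill(mask, -10000.0)  # large negative bias instead of -inf
    return torch.softmax(scores, dim=-1)

b, heads, sq, sk = 2, 4, 16, 16
scores = torch.randn(b, heads, sq, sk)
mask = torch.zeros(b, 1, sq, sk, dtype=torch.bool)
mask[..., sk // 2:] = True                       # mask out the second half of the keys
probs = scaled_masked_softmax_reference(scores, mask, scale=0.125)
print(probs.shape, probs.sum(-1)[0, 0, 0].item())  # rows still sum to 1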
colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp
View file @ 58580b50

@@ -3,52 +3,57 @@
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor);

torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
                       torch::Tensor const& softmax_results,
                       float scale_factor);

torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
                 (input.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return fwd_cuda(input, scale_factor);
}

torch::Tensor bwd(torch::Tensor const& output_grads,
                  torch::Tensor const& softmax_results, float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
                 (output_grads.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
                 (softmax_results.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

}  // end namespace scaled_upper_triang_masked_softmax
}  // end namespace fused_softmax
}  // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
}
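The "upper triangular" variant bound above is the causal case: the same scale-then-softmax, but with every key position above the diagonal masked out so that token i only attends to keys 0..i. A short unfused reference of that math (my own sketch, not the extension's API) is:

import torch

def scaled_causal_softmax_reference(scores, scale):
    # scores: [attn_batches, seq_len, seq_len]
    seq_len = scores.size(-1)
    causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    return torch.softmax(scores * scale + causal * -10000.0, dim=-1)

scores = torch.randn(3, 8, 8)
probs = scaled_causal_softmax_reference(scores, scale=0.5)
print(torch.allclose(probs.triu(1), torch.zeros_like(probs)))  # True: no attention to the future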
colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu
View file @ 58580b50

@@ -2,13 +2,12 @@
 * with minor changes. */

#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"

@@ -16,15 +15,18 @@ namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) {
  // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
  const int attn_batches = input.size(0);
  const int seq_len = input.size(1);
  TORCH_INTERNAL_ASSERT(seq_len <= 2048);

  // Output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, seq_len, seq_len}, act_options);

  // Softmax Intermediate Result Ptr

@@ -34,42 +36,50 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) {
  DISPATCH_HALF_AND_BFLOAT(
      input.scalar_type(),
      "dispatch_scaled_upper_triang_masked_softmax_forward",
      dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr),
          scale_factor, seq_len, seq_len, attn_batches););

  return softmax_results;
}

torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
                       torch::Tensor const& softmax_results_,
                       float scale_factor) {
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
  const int attn_batches = output_grads.size(0);
  const int seq_len = output_grads.size(1);
  TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));

  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  //Softmax Grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(),
      "dispatch_scaled_upper_triang_masked_softmax_backward",
      dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor, seq_len, seq_len, attn_batches););

  //backward pass is completely in-place
  return output_grads;
}
}  // namespace scaled_upper_triang_masked_softmax
}  // namespace fused_softmax
}  // namespace multihead_attn
colossalai/kernel/cuda_native/layer_norm.py
View file @ 58580b50

@@ -24,8 +24,8 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()
        output, mean, invvar = colossal_layer_norm_cuda.forward_affine(
            input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
        return output

@@ -72,7 +72,8 @@ class MixedFusedLayerNorm(torch.nn.Module):
    def forward(self, input):
        return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias,
                                                  self.normalized_shape, self.eps)

    def __repr__(self):
        return f'MixedFusedLayerNorm(normalized_shape={self.normalized_shape}, eps={self.eps})'
colossalai/kernel/cuda_native/scaled_softmax.py
View file @ 58580b50

@@ -28,7 +29,9 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
            raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')

        scale_t = torch.tensor([scale])
        softmax_results = colossal_scaled_upper_triang_masked_softmax.forward(
            inputs, scale_t[0]
        )

        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

@@ -41,7 +43,9 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
            raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')

        softmax_results, scale_t = ctx.saved_tensors
        input_grads = colossal_scaled_upper_triang_masked_softmax.backward(
            output_grads, softmax_results, scale_t[0]
        )
        return input_grads, None

@@ -77,7 +81,9 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
        softmax_results, scale_t = ctx.saved_tensors

        input_grads = colossal_scaled_masked_softmax.backward(
            output_grads, softmax_results, scale_t[0]
        )
        return input_grads, None, None

@@ -108,8 +114,9 @@ class FusedScaleMaskSoftmax(nn.Module):
        super(FusedScaleMaskSoftmax, self).__init__()
        self.input_in_fp16 = input_in_fp16
        self.input_in_bf16 = input_in_bf16
        assert not (
            self.input_in_fp16 and self.input_in_bf16
        ), "both fp16 and bf16 flags cannot be active at the same time."
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion

@@ -117,7 +124,9 @@ class FusedScaleMaskSoftmax(nn.Module):
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        assert (
            self.scale is None or softmax_in_fp32
        ), "softmax should be in fp32 when scaled"

    def forward(self, input, mask):
        # [b, np, sq, sk]

@@ -131,13 +140,14 @@ class FusedScaleMaskSoftmax(nn.Module):
    def is_kernel_available(self, mask, b, np, sq, sk):
        attn_batches = b * np

        if (
            self.scaled_masked_softmax_fusion  # user want to fuse
            and self.input_in_float16  # input must be fp16
            and mask is not None  # mask tensor must not be None
            and 16 < sk <= 2048  # sk must be 16 ~ 2048
            and sq % 4 == 0  # sq must be divisor of 4
            and attn_batches % 4 == 0  # np * b must be divisor of 4
        ):
            if 0 <= sk <= 2048:
                batch_per_block = self.get_batch_per_block(sq, sk, b, np)
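The is_kernel_available checks above gate the fused path: fusion must be requested, the input must be in a 16-bit float type, a mask must be present, the key length must lie in (16, 2048], and both sq and b * np must be divisible by 4. A standalone predicate with the same conditions, written here as a plain function purely for illustration, is:

def can_use_fused_softmax(fusion_enabled, input_in_float16, mask, b, np, sq, sk):
    attn_batches = b * np
    return (
        fusion_enabled                # user wants to fuse
        and input_in_float16          # input must be fp16/bf16
        and mask is not None          # mask tensor must not be None
        and 16 < sk <= 2048           # sk must be 16 ~ 2048
        and sq % 4 == 0               # sq must be divisible by 4
        and attn_batches % 4 == 0     # np * b must be divisible by 4
    )

print(can_use_fused_softmax(True, True, mask=object(), b=8, np=12, sq=128, sk=128))   # True
print(can_use_fused_softmax(True, True, mask=object(), b=8, np=12, sq=128, sk=4096))  # False: sk too large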
colossalai/kernel/jit/bias_gelu.py
View file @ 58580b50

import torch

###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678

@@ -8,12 +9,10 @@ import torch
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))


@torch.jit.script
def bias_gelu(bias, y):
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


# gradient of tanh approximation of gelu
# gradient of actual gelu is:

@@ -24,11 +23,9 @@ def bias_gelu_back(g, bias, y):
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    return ff * g


class GeLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):

@@ -41,5 +38,4 @@ class GeLUFunction(torch.autograd.Function):
        tmp = bias_gelu_back(grad_output, bias, input)
        return tmp, tmp


bias_gelu_impl = GeLUFunction.apply
\ No newline at end of file
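The constants in bias_gelu encode the tanh approximation of GELU: 0.79788456 is roughly sqrt(2/pi), and 0.1070322243 is roughly 3 * 0.044715 * sqrt(2/pi) in the hand-written derivative. A quick comparison against the exact erf-based GELU (my own verification snippet, not part of the file) shows the two agree closely:

import torch

def gelu_tanh(x):
    # same approximation as bias_gelu, with bias = 0
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

def gelu_exact(x):
    return x * 0.5 * (1.0 + torch.erf(x * 0.70710678))

x = torch.linspace(-5, 5, steps=1001)
print((gelu_tanh(x) - gelu_exact(x)).abs().max())  # small (well under 1e-2) on this range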
colossalai/nn/layer/parallel_2d/layers.py
View file @ 58580b50

@@ -182,7 +182,7 @@ class Linear2D(ParallelLayer):
    def forward(self, x: Tensor) -> Tensor:
        # input: [m/q, n/q, k/q]
        # output: [m/q, n/q, h/q]
        out_shape = x.shape[:-1] + (self.hidden_size_per_partition, )
        output = Matmul_AB_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank,
                                    ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank,

@@ -337,16 +337,16 @@ class LayerNorm2D(ParallelLayer):
    def forward(self, x: Tensor) -> Tensor:
        with torch.no_grad():
            E_x = torch.sum(x, dim=-1, keepdim=True)    # [b/q, s, 1]
            torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
            E_x /= self.normalized_shape

            # Var_x in the block below is the sum of input^2
            Var_x = torch.sum(x * x, dim=-1, keepdim=True)    # [b/q, s, 1]
            torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
            Var_x /= self.normalized_shape

            Var_x = Var_x - E_x * E_x    # variance of x [b/q, s, 1]
            # this time 1/sqrt(Var_x + epsilon)
            Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)

@@ -569,7 +569,7 @@ class PatchEmbedding2D(ParallelLayer):
        output = F.conv2d(input_, weight, bias, stride=self.patch_size)
        if self.flatten:
            output = output.flatten(2).transpose(1, 2)    # BCHW -> BNC

        cls_token = all_gather_tensor_2d(self.cls_token, -1, ParallelMode.PARALLEL_2D_COL)
        pos_embed = all_gather_tensor_2d(self.pos_embed, -1, ParallelMode.PARALLEL_2D_COL)

@@ -1012,7 +1012,7 @@ class Classifier2D(ParallelLayer):
        destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        out_shape = input_.shape[:-1] + (self.num_classes, )

        return classifier_2d(input_, self.weight, self.bias, self.summa_dim, out_shape, self.row_rank, self.col_rank,
                             ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank,

@@ -1186,7 +1186,7 @@ class VocabParallelClassifier2D(ParallelLayer):
    def forward(self, x: Tensor) -> Tensor:
        # input: [m/q, n/q, k/q]
        # output: [m/q, n/q, h/q]
        out_shape = x.shape[:-1] + (self.output_size_per_partition, )
        output = Matmul_ABT_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank,
                                     ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
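The LayerNorm2D.forward hunk above computes the row statistics with the identity Var[x] = E[x^2] - E[x]^2, so that both moments can be accumulated with one all-reduce each over the 2D row group. A single-process sketch of the same arithmetic (no distributed group, which is the assumption here) is:

import torch

def layer_norm_stats(x, normalized_shape, eps=1e-5):
    # sums over the hidden dim; in the parallel layer these sums are all-reduced across the row group
    E_x = torch.sum(x, dim=-1, keepdim=True) / normalized_shape          # E[x]
    Var_x = torch.sum(x * x, dim=-1, keepdim=True) / normalized_shape    # E[x^2]
    Var_x = Var_x - E_x * E_x                                            # Var[x] = E[x^2] - E[x]^2
    inv_std = 1.0 / torch.sqrt(Var_x + eps)
    return E_x, inv_std

x = torch.randn(4, 16, 1024)
mean, inv_std = layer_norm_stats(x, normalized_shape=1024)
ref = torch.nn.functional.layer_norm(x, (1024,))
print(torch.allclose((x - mean) * inv_std, ref, atol=1e-4))  # True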
colossalai/nn/layer/parallel_2p5d/layers.py
View file @
58580b50
...
@@ -189,7 +189,7 @@ class Linear2p5D(ParallelLayer):
...
@@ -189,7 +189,7 @@ class Linear2p5D(ParallelLayer):
def
forward
(
self
,
x
:
Tensor
)
->
Tensor
:
def
forward
(
self
,
x
:
Tensor
)
->
Tensor
:
# input: [m/dq, n/q, k/q]
# input: [m/dq, n/q, k/q]
# output: [m/dq, n/q, h/q]
# output: [m/dq, n/q, h/q]
out_shape
=
x
.
shape
[:
-
1
]
+
(
self
.
hidden_size_per_partition
,)
out_shape
=
x
.
shape
[:
-
1
]
+
(
self
.
hidden_size_per_partition
,
)
output
=
Matmul_AB_2p5D
.
apply
(
output
=
Matmul_AB_2p5D
.
apply
(
x
,
x
,
...
@@ -254,7 +254,7 @@ class LayerNorm2p5D(ParallelLayer):
        self.tesseract_dim, _ = get_tesseract_dim_dep_from_env()
        # partitioning dimension
        self.partitioned_partition = divide(normalized_shape, self.tesseract_dim)    # *
        # create parameters
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
...
@@ -357,16 +357,16 @@ class LayerNorm2p5D(ParallelLayer):
    def forward(self, x: Tensor) -> Tensor:
        with torch.no_grad():
            E_x = torch.sum(x, dim=-1, keepdim=True)    # [b/q, s, 1]
            torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
            E_x /= self.normalized_shape
            # Var_x in the block below is the sum of input^2
            Var_x = torch.sum(x * x, dim=-1, keepdim=True)    # [b/q, s, 1]
            torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
            Var_x /= self.normalized_shape
            Var_x = Var_x - E_x * E_x    # variance of x [b/q, s, 1]
            # this time 1/sqrt(Var_x + epsilon)
            Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
...
@@ -589,7 +589,7 @@ class PatchEmbedding2p5D(ParallelLayer):
        output = F.conv2d(input_, weight, bias, stride=self.patch_size)
        if self.flatten:
            output = output.flatten(2).transpose(1, 2)    # BCHW -> BNC
        cls_token = all_gather_tensor_2p5d(self.cls_token, -1, ParallelMode.PARALLEL_2P5D_COL)
        pos_embed = all_gather_tensor_2p5d(self.pos_embed, -1, ParallelMode.PARALLEL_2P5D_COL)
...
@@ -1038,7 +1038,7 @@ class Classifier2p5D(ParallelLayer):
        destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        out_shape = input_.shape[:-1] + (self.num_classes,)
        return classifier_2p5d(input_, self.weight, self.bias, self.tesseract_dim, out_shape, self.row_rank,
                               self.col_rank, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL,
...
@@ -1172,7 +1172,7 @@ class VocabParallelClassifier2p5D(ParallelLayer):
    def forward(self, x: Tensor) -> Tensor:
        # input: [m/dq, n/q, k/q]
        # output: [m/dq, n/q, h/q]
        out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)
        output = Matmul_ABT_2p5D.apply(
            x,
...
colossalai/nn/layer/parallel_3d/layers.py
View file @ 58580b50
...
@@ -53,8 +53,8 @@ class LayerNorm3D(ParallelLayer):
        self.weight = Parameter(
            torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
        if bias:
            self.bias = Parameter(
                torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
        else:
            self.bias = None
        self.variance_epsilon = eps
...
@@ -854,7 +854,7 @@ class PatchEmbedding3D(ParallelLayer):
        input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
        output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
        if self.flatten:
            output = output.flatten(2).transpose(1, 2)    # BCHW -> BNC
        cls_token = self.cls_token.expand(output.shape[0], -1, -1)
        output = torch.cat((cls_token, output), dim=1)
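The last two lines broadcast the class token over the batch and prepend it to the patch tokens. A small shape-only sketch in plain PyTorch, with illustrative sizes:

# Prepending a class token, mirroring the expand/cat above (sizes illustrative).
import torch

batch, num_patches, dim = 4, 196, 768
output = torch.randn(batch, num_patches, dim)
cls_token_param = torch.zeros(1, 1, dim)                       # a learned [1, 1, D] parameter in the real layer
cls_token = cls_token_param.expand(output.shape[0], -1, -1)    # [B, 1, D] broadcast, no copy
output = torch.cat((cls_token, output), dim=1)                 # [B, 1 + num_patches, D]
print(output.shape)                                            # torch.Size([4, 197, 768])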
...
colossalai/nn/layer/utils/common.py
View file @ 58580b50
...
@@ -13,8 +13,7 @@ from torch import Tensor, nn
class CheckpointModule(nn.Module):

    def __init__(self, checkpoint: bool = True, offload: bool = False):
        super().__init__()
        self.checkpoint = checkpoint
        self._use_checkpoint = checkpoint
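The hunk only shows the constructor storing the checkpoint flag. Below is a hedged sketch of how such a flag is typically consumed with torch.utils.checkpoint; the wrapper class and `_forward` method are assumptions for illustration, not the actual CheckpointModule implementation.

# Illustrative use of an activation-checkpoint flag (not the ColossalAI implementation).
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TinyCheckpointed(nn.Module):
    def __init__(self, use_checkpoint: bool = True):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.body = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))

    def _forward(self, x):
        return self.body(x)

    def forward(self, x):
        if self.use_checkpoint and self.training:
            # Recompute `_forward` activations during backward instead of storing them.
            return checkpoint(self._forward, x, use_reentrant=False)
        return self._forward(x)

m = TinyCheckpointed().train()
y = m(torch.randn(2, 16, requires_grad=True))
y.sum().backward()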
...
@@ -79,7 +78,6 @@ def get_tensor_parallel_mode():
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
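The diff truncates `parse` after the iterable branch. The usual completion of this helper pattern repeats a scalar `n` times; the non-iterable branch below is an assumption, since those lines are not shown here.

# Common completion of the _ntuple pattern (the final branch is an assumption).
import collections.abc
from itertools import repeat

def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse

to_2tuple = _ntuple(2)
print(to_2tuple(7))        # (7, 7)
print(to_2tuple((3, 5)))   # (3, 5)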
...