OpenDAS / deepspeed / Commits

Commit eadbbe09 authored Apr 25, 2021 by 401qingkong

    push rocm deepspeed v0.3.13

parent ab5534fc
Changes: 155
Showing 20 changed files with 5403 additions and 0 deletions (+5403 −0)
deepspeed/ops/csrc/includes/hip/general_kernels.h        +47    -0
deepspeed/ops/csrc/includes/hip/normalize_layer.h        +202   -0
deepspeed/ops/csrc/includes/hip/softmax.h                +60    -0
deepspeed/ops/csrc/includes/hip/strided_batch_gemm.h     +179   -0
deepspeed/ops/csrc/includes/hip/type_shim.h              +110   -0
deepspeed/ops/csrc/includes/normalize_layer.h            +202   -0
deepspeed/ops/csrc/includes/softmax.h                    +60    -0
deepspeed/ops/csrc/includes/strided_batch_gemm.h         +179   -0
deepspeed/ops/csrc/includes/type_shim.h                  +109   -0
deepspeed/ops/csrc/lamb/fused_lamb_cuda.cpp              +109   -0
deepspeed/ops/csrc/lamb/fused_lamb_cuda_kernel.cu        +469   -0
deepspeed/ops/csrc/lamb/hip/fused_lamb_hip.cpp           +109   -0
deepspeed/ops/csrc/lamb/hip/fused_lamb_hip_kernel.hip    +469   -0
deepspeed/ops/csrc/sparse_attention/hip/utils.cpp        +120   -0
deepspeed/ops/csrc/sparse_attention/utils.cpp            +120   -0
deepspeed/ops/csrc/transformer/cublas_wrappers.cu        +199   -0
deepspeed/ops/csrc/transformer/dropout_kernels.cu        +868   -0
deepspeed/ops/csrc/transformer/ds_transformer_cuda.cpp   +1046  -0
deepspeed/ops/csrc/transformer/gelu_kernels.cu           +335   -0
deepspeed/ops/csrc/transformer/general_kernels.cu        +411   -0
deepspeed/ops/csrc/includes/hip/general_kernels.h  0 → 100644
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include <stdlib.h>
#include <cooperative_groups.h>
#include <hiprand_kernel.h>
#include "context.h"
#include "cublas_wrappers.h"
#define THREADS 256
#define TILE_DIM 32
#define minus_infinity -1 * std::numeric_limits<float>::infinity()
#define FINAL_MASK 0xffffffff
template <typename T>
void launch_fused_add2(T* out,
                       const T* inp1,
                       const T* inp2,
                       int batch_size,
                       int seq_length,
                       int hidden_size,
                       hipStream_t& stream);

template <typename T>
void launch_fused_add4(T* out,
                       const T* inp1,
                       const T* inp2,
                       const T* inp3,
                       const T* inp4,
                       int batch_size,
                       int seq_length,
                       int hidden_size,
                       hipStream_t& stream);

template <typename T>
void launch_fused_add3(T* out,
                       const T* inp1,
                       const T* inp2,
                       const T* inp3,
                       int batch_size,
                       int seq_length,
                       int hidden_size,
                       hipStream_t& stream);
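For orientation, here is a minimal host-side sketch of how one of these launchers might be driven. It assumes the kernel implementations (which live elsewhere in the tree, not in this header) are compiled and linked, and that the three pointers are device buffers of batch * seq * hidden floats.

// Illustrative only: a sketch, not part of this commit.
#include <hip/hip_runtime.h>
#include "general_kernels.h"

void run_fused_add2(float* out, const float* inp1, const float* inp2)
{
    const int batch = 8, seq = 128, hidden = 1024;

    hipStream_t stream;
    hipStreamCreate(&stream);

    // Launches the fused element-wise add of inp1 and inp2 into out on `stream`.
    launch_fused_add2<float>(out, inp1, inp2, batch, seq, hidden, stream);

    hipStreamSynchronize(stream);
    hipStreamDestroy(stream);
}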
deepspeed/ops/csrc/includes/hip/normalize_layer.h  0 → 100644
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include <fstream>
#include "custom_cuda_layers.h"
using namespace std;

template <typename T>
class Normalize_Layer {
public:
    struct Config {
        uint32_t batchSize;
        uint32_t seqLength;
        uint32_t hiddenDim;
        float epsilon;
        bool training;
        bool useMean;
        Config(uint32_t batch,
               uint32_t seq,
               uint32_t h,
               float epsilon = 1e-12,
               bool training = true,
               bool useMean = true)
            : batchSize(batch),
              seqLength(seq),
              hiddenDim(h),
              epsilon(epsilon),
              training(training),
              useMean(useMean)
        {
        }
    };

    Normalize_Layer(Config config)
        : config_(config), vars(nullptr), means(nullptr), vals_hat(nullptr)
    {
    }

    ~Normalize_Layer() {}

    void ForwardCheckpoint(int bsz,  // batch * seq
                           T* vals,
                           const T* residual,
                           const T* gamma,
                           const T* betta,
                           hipStream_t& stream,
                           bool preLayerNorm = false)
    {
        launch_bias_residual_layer_norm(vals,
                                        residual,
                                        gamma,
                                        betta,
                                        config_.epsilon,
                                        bsz,
                                        config_.hiddenDim,
                                        stream,
                                        preLayerNorm,
                                        config_.training,
                                        vars,
                                        means);
    }

    void Forward(int bsz,
                 T* vals,
                 const T* residual,
                 const T* gamma,
                 const T* betta,
                 hipStream_t& stream,
                 bool preLayerNorm = false)
    {
        launch_bias_residual_layer_norm(vals,
                                        residual,
                                        gamma,
                                        betta,
                                        config_.epsilon,
                                        bsz,
                                        config_.hiddenDim,
                                        stream,
                                        preLayerNorm,
                                        config_.training,
                                        vars);
    }

    void Backward(int bsz,
                  const T* out_grad,
                  const T* gamma,
                  T* gamma_grad,
                  T* betta_grad,
                  hipStream_t stream[2],
                  T* inp_grad_out,
                  const T* norm_in = nullptr)
    {
        launch_layerNorm_backward(out_grad,
                                  norm_in,
                                  vars,
                                  means,
                                  gamma,
                                  gamma_grad,
                                  betta_grad,
                                  inp_grad_out,
                                  bsz,
                                  config_.hiddenDim,
                                  stream);
    }

    void Backward(int bsz,
                  const T* out_grad,
                  const T* gamma,
                  const T* betta,
                  T* gamma_grad,
                  T* betta_grad,
                  hipStream_t stream[2],
                  T* inp_grad_out,
                  const T* norm_out)
    {
        launch_layerNorm_backward(out_grad,
                                  norm_out,
                                  vars,
                                  gamma,
                                  gamma_grad,
                                  betta_grad,
                                  inp_grad_out,
                                  bsz,
                                  config_.hiddenDim,
                                  stream,
                                  !config_.useMean,
                                  betta);
    }

    void BackwardFusedAdd(int bsz,
                          const T* out_grad1,
                          const T* out_grad2,
                          const T* gamma,
                          T* gamma_grad,
                          T* betta_grad,
                          hipStream_t stream[2],
                          T* inp_grad_out,
                          const T* norm_in = nullptr)
    {
        launch_layerNorm_backward_fused_add(out_grad1,
                                            out_grad2,
                                            norm_in,
                                            vars,
                                            means,
                                            gamma,
                                            gamma_grad,
                                            betta_grad,
                                            inp_grad_out,
                                            bsz,
                                            config_.hiddenDim,
                                            stream);
    }

    void BackwardFusedAdd(int bsz,
                          const T* out_grad1,
                          const T* out_grad2,
                          const T* gamma,
                          const T* betta,
                          T* gamma_grad,
                          T* betta_grad,
                          hipStream_t stream[2],
                          T* inp_grad_out,
                          const T* norm_out)
    {
        launch_layerNorm_backward_fused_add(out_grad1,
                                            out_grad2,
                                            norm_out,
                                            vars,
                                            gamma,
                                            gamma_grad,
                                            betta_grad,
                                            inp_grad_out,
                                            bsz,
                                            config_.hiddenDim,
                                            stream,
                                            !config_.useMean,
                                            betta);
    }

    inline bool UseMean() const { return config_.useMean; }

    inline void SetVar(T* variance)
    {
        if (!variance) { throw std::runtime_error("Normalize variance is null."); }
        vars = variance;
    }

    inline void SetMean(T* mean)
    {
        if (!mean) { throw std::runtime_error("Normalize mean is null."); }
        means = mean;
    }

private:
    Config config_;
    T* vars;
    T* means;
    T* vals_hat;
};
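A minimal usage sketch of this wrapper, under the assumption that the launch_* kernels declared in custom_cuda_layers.h are linked in and that every pointer is a device buffer sized for batch * seq rows of width hidden:

// Illustrative only: a sketch, not part of this commit.
void run_layer_norm(float* d_vals, const float* d_residual,
                    const float* d_gamma, const float* d_betta,
                    float* d_var, float* d_mean, hipStream_t stream)
{
    const int batch = 8, seq = 128, hidden = 1024;
    Normalize_Layer<float>::Config cfg(batch, seq, hidden, 1e-12, /*training=*/true,
                                       /*useMean=*/true);
    Normalize_Layer<float> norm(cfg);

    norm.SetVar(d_var);    // per-row variance scratch, reused by the backward pass
    norm.SetMean(d_mean);  // per-row mean scratch, only kept because useMean is true

    // One fused bias/residual + LayerNorm launch over batch * seq rows.
    norm.ForwardCheckpoint(batch * seq, d_vals, d_residual, d_gamma, d_betta, stream);
}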
deepspeed/ops/csrc/includes/hip/softmax.h  0 → 100644
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include "custom_cuda_layers.h"
#include <fstream>
using namespace std;

template <typename T>
class Softmax {
public:
    struct Config {
        size_t batchSize;
        size_t heads;
        size_t seq_length;
        size_t prob_depth;
        float temprature;
        bool mem_alloc;
        Config(size_t batch, size_t h, size_t seq, int prob_size = 0, bool mem_alloc = false)
            : batchSize(batch),
              heads(h),
              seq_length(seq),
              prob_depth(prob_size),
              temprature(1.0),
              mem_alloc(mem_alloc)
        {
        }
    };

    Softmax(Config config) : config_(config) {}

    ~Softmax() {}

    void Forward(int bsz, T* vals, const T* attn_mask, hipStream_t& stream)
    {
        launch_attn_softmax<T>(vals, attn_mask, bsz, config_.heads, config_.seq_length, stream);
    }

    void Backward(int bsz, T* out_grad, const T* soft_out, hipStream_t stream)
    {
        launch_attn_softmax_backward_v2<T>(
            out_grad, soft_out, bsz, config_.heads, config_.seq_length, stream);
    }

    inline size_t GetProbDepth() const { return config_.prob_depth; }

    inline size_t GetBatchSize() const { return config_.batchSize; }

    inline size_t GetNumHeads() const { return config_.heads; }

    inline size_t GetSeqLength() const { return config_.seq_length; }

    inline void SetSeqLength(size_t seq_len) { config_.seq_length = seq_len; }

private:
    Config config_;
};
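A short sketch of how this class is typically driven; the buffer layout ([batch, heads, seq, seq] logits plus an attention mask) is an assumption here, since the kernels it forwards to are declared in custom_cuda_layers.h and are not part of this commit.

// Illustrative only: d_scores and d_mask are assumed to be device buffers filled elsewhere.
void run_attn_softmax(float* d_scores, const float* d_mask, hipStream_t stream)
{
    const size_t batch = 8, heads = 16, seq = 128;
    Softmax<float>::Config cfg(batch, heads, seq);
    Softmax<float> softmax(cfg);

    // In-place softmax over the attention scores; the mask is passed through to the kernel.
    softmax.Forward(static_cast<int>(batch), d_scores, d_mask, stream);
}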
deepspeed/ops/csrc/includes/hip/strided_batch_gemm.h  0 → 100644
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include "context.h"
template <typename T>
class StridedBatchGemm {
public:
    struct Config {
        int batch_size;
        int m;
        int n;
        int k;
        float alpha;
        float beta;
        rocblas_operation op_A;
        rocblas_operation op_B;
        std::array<int, 3> gemm_algos;

        Config(int batch,
               int mm,
               int nn,
               int kk,
               float param_alpha,
               float param_beta,
               rocblas_operation opA,
               rocblas_operation opB,
               const std::array<int, 3>& algos)
            : batch_size(batch),
              m(mm),
              n(nn),
              k(kk),
              alpha(param_alpha),
              beta(param_beta),
              op_A(opA),
              op_B(opB),
              gemm_algos(algos)
        {
        }
        void SetConfig(int mm, int nn, int kk)
        {
            m = mm;
            n = nn;
            k = kk;
        }
    };

    StridedBatchGemm(const Config& config) : _config(config) {}

    virtual ~StridedBatchGemm() {}

    void Forward(int bsz, T* output, const T* _buffer_a, const T* _buffer_b, rocblas_handle handle)
    {
        int stride_a = _config.m * _config.k;
        int stride_b = _config.n * _config.k;
        int stride_c = _config.m * _config.n;

        cublas_strided_batched_gemm(handle,
                                    _config.m,
                                    _config.n,
                                    _config.k,
                                    &_config.alpha,
                                    &_config.beta,
                                    _buffer_a,
                                    _buffer_b,
                                    output,
                                    _config.op_A,
                                    _config.op_B,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    bsz,
                                    cublasGemmAlgo_t(_config.gemm_algos[0]));
    }

    void ForwardPlusSave(T* output, const T* _buffer_a, const T* _buffer_b, rocblas_handle handle)
    {
        int stride_a = _config.m * _config.k;
        int stride_b = _config.n * _config.k;
        int stride_c = _config.m * _config.n;

        cublas_strided_batched_gemm(handle,
                                    _config.m,
                                    _config.n,
                                    _config.k,
                                    &_config.alpha,
                                    &_config.beta,
                                    _buffer_a,
                                    _buffer_b,
                                    output,
                                    _config.op_A,
                                    _config.op_B,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    _config.batch_size,
                                    cublasGemmAlgo_t(_config.gemm_algos[0]));

        k_buf = _buffer_a;
        q_buf = _buffer_b;
    }

    void Backward(int bsz,
                  const T* d_output,
                  const T* _buffer_a,
                  const T* _buffer_b,
                  rocblas_handle handle,
                  T* inpGradA = nullptr,
                  T* inpGradB = nullptr)
    {
        int mb = (_config.op_A == rocblas_operation_transpose ? _config.k : _config.m);
        int kb = (_config.op_A == rocblas_operation_transpose ? _config.m : _config.k);

        int stride_a = mb * _config.n;
        int stride_b = _config.n * kb;
        int stride_c = _config.m * _config.k;

        // B need to transpose.
        rocblas_operation op_b = (_config.op_B == rocblas_operation_transpose
                                      ? rocblas_operation_none
                                      : rocblas_operation_transpose);

        // Calculate d_A.
        cublas_strided_batched_gemm(handle,
                                    mb,
                                    kb,
                                    _config.n,
                                    &_config.alpha,
                                    &_config.beta,
                                    (_config.op_A == rocblas_operation_transpose ? _buffer_b
                                                                                 : d_output),
                                    (_config.op_A == rocblas_operation_transpose ? d_output
                                                                                 : _buffer_b),
                                    inpGradA,
                                    rocblas_operation_none,
                                    op_b,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    bsz,
                                    cublasGemmAlgo_t(_config.gemm_algos[1]));

        // A need to transpose.
        rocblas_operation op_a = (_config.op_A == rocblas_operation_transpose
                                      ? rocblas_operation_none
                                      : rocblas_operation_transpose);

        stride_a = _config.m * _config.k;
        stride_b = _config.m * _config.n;
        stride_c = _config.n * _config.k;

        // Calculate d_B.
        cublas_strided_batched_gemm(handle,
                                    _config.k,
                                    _config.n,
                                    _config.m,
                                    &_config.alpha,
                                    &_config.beta,
                                    _buffer_a,
                                    d_output,
                                    inpGradB,
                                    op_a,
                                    rocblas_operation_none,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    bsz,
                                    cublasGemmAlgo_t(_config.gemm_algos[2]));
    }

    inline int GetN() const { return _config.k; }

    inline const T* GetBufferA() const { return k_buf; }

    inline const T* GetBufferB() const { return q_buf; }

    inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }

private:
    Config _config;
    const T* q_buf;
    const T* k_buf;
};
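For context, a sketch of configuring one batched GEMM per (batch, head) pair, as it might be used for attention scores. The operand roles, transpose choices, and algo indices here are assumptions for illustration, not taken from this commit.

// Illustrative only: d_scores, d_k, d_q are assumed device buffers; the wrapper forwards the
// algo integers to cublasGemmAlgo_t unchanged.
void run_score_gemm(float* d_scores, const float* d_k, const float* d_q, rocblas_handle handle)
{
    const int batch_heads = 8 * 16, seq = 128, head_dim = 64;
    std::array<int, 3> algos = {-1, -1, -1};  // default GEMM algorithm for all three calls

    StridedBatchGemm<float>::Config cfg(batch_heads, seq, seq, head_dim,
                                        /*alpha=*/0.125f,  // 1/sqrt(head_dim) for head_dim = 64
                                        /*beta=*/0.0f,
                                        rocblas_operation_transpose,
                                        rocblas_operation_none,
                                        algos);
    StridedBatchGemm<float> gemm(cfg);

    // One strided-batched GEMM launch covering all batch * heads score matrices.
    gemm.Forward(batch_heads, d_scores, d_k, d_q, handle);
}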
deepspeed/ops/csrc/includes/hip/type_shim.h  0 → 100644
#include "hip/hip_runtime.h"
/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */
#include <ATen/ATen.h>
// Forward/backward compatiblity hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(T* x,
                                                     T val,
                                                     int lanes = 1,
                                                     bool share_result = false)
// lanes is intended to be <= 32.
{
    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    int blockSize = blockDim.x * blockDim.y;
    // blockSize is intended to be a multiple of 32.

    if (blockSize >= 64) {
        x[tid] = val;
        __syncthreads();
    }

#pragma unroll
    for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
        if (tid < i) x[tid] = x[tid] + x[tid + i];
        __syncthreads();
    }

    T final;

    if (tid < 32) {
        if (blockSize >= 64)
            final = x[tid] + x[tid + 32];
        else
            final = val;
        // __SYNCWARP();

#pragma unroll
        for (int i = 16; i >= lanes; i >>= 1)
            final = final + __shfl_down_sync(0xffffffff, final, i);
    }

    if (share_result) {
        if (tid < lanes) x[tid] = final;  // EpilogueOp
        // Make sure the smem result is visible to all warps.
        __syncthreads();
    }

    return final;
}
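To show how the DISPATCH_* macros above are meant to be used: they switch on the runtime ATen dtype and define a scalar_t_<LEVEL> alias inside each case. The launcher name in this sketch is hypothetical; only the macro itself comes from this file.

// Illustrative only: `norm_launcher` is a hypothetical templated host launcher.
void dispatch_example(at::Tensor& input, hipStream_t stream)
{
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "norm_launcher",
                            norm_launcher<scalar_t_0>(
                                input.data_ptr<scalar_t_0>(), input.numel(), stream););
}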
deepspeed/ops/csrc/includes/normalize_layer.h  0 → 100644
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <fstream>
#include "custom_cuda_layers.h"
using namespace std;

template <typename T>
class Normalize_Layer {
public:
    struct Config {
        uint32_t batchSize;
        uint32_t seqLength;
        uint32_t hiddenDim;
        float epsilon;
        bool training;
        bool useMean;
        Config(uint32_t batch,
               uint32_t seq,
               uint32_t h,
               float epsilon = 1e-12,
               bool training = true,
               bool useMean = true)
            : batchSize(batch),
              seqLength(seq),
              hiddenDim(h),
              epsilon(epsilon),
              training(training),
              useMean(useMean)
        {
        }
    };

    Normalize_Layer(Config config)
        : config_(config), vars(nullptr), means(nullptr), vals_hat(nullptr)
    {
    }

    ~Normalize_Layer() {}

    void ForwardCheckpoint(int bsz,  // batch * seq
                           T* vals,
                           const T* residual,
                           const T* gamma,
                           const T* betta,
                           cudaStream_t& stream,
                           bool preLayerNorm = false)
    {
        launch_bias_residual_layer_norm(vals,
                                        residual,
                                        gamma,
                                        betta,
                                        config_.epsilon,
                                        bsz,
                                        config_.hiddenDim,
                                        stream,
                                        preLayerNorm,
                                        config_.training,
                                        vars,
                                        means);
    }

    void Forward(int bsz,
                 T* vals,
                 const T* residual,
                 const T* gamma,
                 const T* betta,
                 cudaStream_t& stream,
                 bool preLayerNorm = false)
    {
        launch_bias_residual_layer_norm(vals,
                                        residual,
                                        gamma,
                                        betta,
                                        config_.epsilon,
                                        bsz,
                                        config_.hiddenDim,
                                        stream,
                                        preLayerNorm,
                                        config_.training,
                                        vars);
    }

    void Backward(int bsz,
                  const T* out_grad,
                  const T* gamma,
                  T* gamma_grad,
                  T* betta_grad,
                  cudaStream_t stream[2],
                  T* inp_grad_out,
                  const T* norm_in = nullptr)
    {
        launch_layerNorm_backward(out_grad,
                                  norm_in,
                                  vars,
                                  means,
                                  gamma,
                                  gamma_grad,
                                  betta_grad,
                                  inp_grad_out,
                                  bsz,
                                  config_.hiddenDim,
                                  stream);
    }

    void Backward(int bsz,
                  const T* out_grad,
                  const T* gamma,
                  const T* betta,
                  T* gamma_grad,
                  T* betta_grad,
                  cudaStream_t stream[2],
                  T* inp_grad_out,
                  const T* norm_out)
    {
        launch_layerNorm_backward(out_grad,
                                  norm_out,
                                  vars,
                                  gamma,
                                  gamma_grad,
                                  betta_grad,
                                  inp_grad_out,
                                  bsz,
                                  config_.hiddenDim,
                                  stream,
                                  !config_.useMean,
                                  betta);
    }

    void BackwardFusedAdd(int bsz,
                          const T* out_grad1,
                          const T* out_grad2,
                          const T* gamma,
                          T* gamma_grad,
                          T* betta_grad,
                          cudaStream_t stream[2],
                          T* inp_grad_out,
                          const T* norm_in = nullptr)
    {
        launch_layerNorm_backward_fused_add(out_grad1,
                                            out_grad2,
                                            norm_in,
                                            vars,
                                            means,
                                            gamma,
                                            gamma_grad,
                                            betta_grad,
                                            inp_grad_out,
                                            bsz,
                                            config_.hiddenDim,
                                            stream);
    }

    void BackwardFusedAdd(int bsz,
                          const T* out_grad1,
                          const T* out_grad2,
                          const T* gamma,
                          const T* betta,
                          T* gamma_grad,
                          T* betta_grad,
                          cudaStream_t stream[2],
                          T* inp_grad_out,
                          const T* norm_out)
    {
        launch_layerNorm_backward_fused_add(out_grad1,
                                            out_grad2,
                                            norm_out,
                                            vars,
                                            gamma,
                                            gamma_grad,
                                            betta_grad,
                                            inp_grad_out,
                                            bsz,
                                            config_.hiddenDim,
                                            stream,
                                            !config_.useMean,
                                            betta);
    }

    inline bool UseMean() const { return config_.useMean; }

    inline void SetVar(T* variance)
    {
        if (!variance) { throw std::runtime_error("Normalize variance is null."); }
        vars = variance;
    }

    inline void SetMean(T* mean)
    {
        if (!mean) { throw std::runtime_error("Normalize mean is null."); }
        means = mean;
    }

private:
    Config config_;
    T* vars;
    T* means;
    T* vals_hat;
};
deepspeed/ops/csrc/includes/softmax.h  0 → 100755
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include "custom_cuda_layers.h"
#include <fstream>
using namespace std;

template <typename T>
class Softmax {
public:
    struct Config {
        size_t batchSize;
        size_t heads;
        size_t seq_length;
        size_t prob_depth;
        float temprature;
        bool mem_alloc;
        Config(size_t batch, size_t h, size_t seq, int prob_size = 0, bool mem_alloc = false)
            : batchSize(batch),
              heads(h),
              seq_length(seq),
              prob_depth(prob_size),
              temprature(1.0),
              mem_alloc(mem_alloc)
        {
        }
    };

    Softmax(Config config) : config_(config) {}

    ~Softmax() {}

    void Forward(int bsz, T* vals, const T* attn_mask, cudaStream_t& stream)
    {
        launch_attn_softmax<T>(vals, attn_mask, bsz, config_.heads, config_.seq_length, stream);
    }

    void Backward(int bsz, T* out_grad, const T* soft_out, cudaStream_t stream)
    {
        launch_attn_softmax_backward_v2<T>(
            out_grad, soft_out, bsz, config_.heads, config_.seq_length, stream);
    }

    inline size_t GetProbDepth() const { return config_.prob_depth; }

    inline size_t GetBatchSize() const { return config_.batchSize; }

    inline size_t GetNumHeads() const { return config_.heads; }

    inline size_t GetSeqLength() const { return config_.seq_length; }

    inline void SetSeqLength(size_t seq_len) { config_.seq_length = seq_len; }

private:
    Config config_;
};
deepspeed/ops/csrc/includes/strided_batch_gemm.h  0 → 100644
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include "context.h"
template <typename T>
class StridedBatchGemm {
public:
    struct Config {
        int batch_size;
        int m;
        int n;
        int k;
        float alpha;
        float beta;
        cublasOperation_t op_A;
        cublasOperation_t op_B;
        std::array<int, 3> gemm_algos;

        Config(int batch,
               int mm,
               int nn,
               int kk,
               float param_alpha,
               float param_beta,
               cublasOperation_t opA,
               cublasOperation_t opB,
               const std::array<int, 3>& algos)
            : batch_size(batch),
              m(mm),
              n(nn),
              k(kk),
              alpha(param_alpha),
              beta(param_beta),
              op_A(opA),
              op_B(opB),
              gemm_algos(algos)
        {
        }
        void SetConfig(int mm, int nn, int kk)
        {
            m = mm;
            n = nn;
            k = kk;
        }
    };

    StridedBatchGemm(const Config& config) : _config(config) {}

    virtual ~StridedBatchGemm() {}

    void Forward(int bsz, T* output, const T* _buffer_a, const T* _buffer_b, cublasHandle_t handle)
    {
        int stride_a = _config.m * _config.k;
        int stride_b = _config.n * _config.k;
        int stride_c = _config.m * _config.n;

        cublas_strided_batched_gemm(handle,
                                    _config.m,
                                    _config.n,
                                    _config.k,
                                    &_config.alpha,
                                    &_config.beta,
                                    _buffer_a,
                                    _buffer_b,
                                    output,
                                    _config.op_A,
                                    _config.op_B,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    bsz,
                                    cublasGemmAlgo_t(_config.gemm_algos[0]));
    }

    void ForwardPlusSave(T* output, const T* _buffer_a, const T* _buffer_b, cublasHandle_t handle)
    {
        int stride_a = _config.m * _config.k;
        int stride_b = _config.n * _config.k;
        int stride_c = _config.m * _config.n;

        cublas_strided_batched_gemm(handle,
                                    _config.m,
                                    _config.n,
                                    _config.k,
                                    &_config.alpha,
                                    &_config.beta,
                                    _buffer_a,
                                    _buffer_b,
                                    output,
                                    _config.op_A,
                                    _config.op_B,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    _config.batch_size,
                                    cublasGemmAlgo_t(_config.gemm_algos[0]));

        k_buf = _buffer_a;
        q_buf = _buffer_b;
    }

    void Backward(int bsz,
                  const T* d_output,
                  const T* _buffer_a,
                  const T* _buffer_b,
                  cublasHandle_t handle,
                  T* inpGradA = nullptr,
                  T* inpGradB = nullptr)
    {
        int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m);
        int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k);

        int stride_a = mb * _config.n;
        int stride_b = _config.n * kb;
        int stride_c = _config.m * _config.k;

        // B need to transpose.
        cublasOperation_t op_b = (_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);

        // Calculate d_A.
        cublas_strided_batched_gemm(handle,
                                    mb,
                                    kb,
                                    _config.n,
                                    &_config.alpha,
                                    &_config.beta,
                                    (_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output),
                                    (_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b),
                                    inpGradA,
                                    CUBLAS_OP_N,
                                    op_b,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    bsz,
                                    cublasGemmAlgo_t(_config.gemm_algos[1]));

        // A need to transpose.
        cublasOperation_t op_a = (_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);

        stride_a = _config.m * _config.k;
        stride_b = _config.m * _config.n;
        stride_c = _config.n * _config.k;

        // Calculate d_B.
        cublas_strided_batched_gemm(handle,
                                    _config.k,
                                    _config.n,
                                    _config.m,
                                    &_config.alpha,
                                    &_config.beta,
                                    _buffer_a,
                                    d_output,
                                    inpGradB,
                                    op_a,
                                    CUBLAS_OP_N,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    bsz,
                                    cublasGemmAlgo_t(_config.gemm_algos[2]));
    }

    inline int GetN() const { return _config.k; }

    inline const T* GetBufferA() const { return k_buf; }

    inline const T* GetBufferB() const { return q_buf; }

    inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }

private:
    Config _config;
    const T* q_buf;
    const T* k_buf;
};
deepspeed/ops/csrc/includes/type_shim.h  0 → 100644
/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */
#include <ATen/ATen.h>
// Forward/backward compatiblity hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(T* x,
                                                     T val,
                                                     int lanes = 1,
                                                     bool share_result = false)
// lanes is intended to be <= 32.
{
    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    int blockSize = blockDim.x * blockDim.y;
    // blockSize is intended to be a multiple of 32.

    if (blockSize >= 64) {
        x[tid] = val;
        __syncthreads();
    }

#pragma unroll
    for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
        if (tid < i) x[tid] = x[tid] + x[tid + i];
        __syncthreads();
    }

    T final;

    if (tid < 32) {
        if (blockSize >= 64)
            final = x[tid] + x[tid + 32];
        else
            final = val;
        // __SYNCWARP();

#pragma unroll
        for (int i = 16; i >= lanes; i >>= 1)
            final = final + __shfl_down_sync(0xffffffff, final, i);
    }

    if (share_result) {
        if (tid < lanes) x[tid] = final;  // EpilogueOp
        // Make sure the smem result is visible to all warps.
        __syncthreads();
    }

    return final;
}
deepspeed/ops/csrc/lamb/fused_lamb_cuda.cpp  0 → 100644
/* Copyright 2019 The Microsoft DeepSpeed Team */
#include <torch/extension.h>
// CUDA forward declaration
void fused_lamb_cuda(at::Tensor& p,
                     at::Tensor& p_copy,
                     at::Tensor& m,
                     at::Tensor& v,
                     at::Tensor& g,
                     float lr,
                     float beta1,
                     float beta2,
                     float max_coeff,
                     float min_coeff,
                     float eps,
                     float grad_scale,
                     int step,
                     int mode,
                     int bias_correction,
                     float decay,
                     at::Tensor& w_l2_i,
                     at::Tensor& u_l2_i,
                     at::Tensor& lamb_coeff_val);

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x)

// C++ interface
at::Tensor lamb(at::Tensor& p,
                at::Tensor& p_copy,
                at::Tensor& m,
                at::Tensor& v,
                at::Tensor& g,
                float lr,
                float beta1,
                float beta2,
                float max_coeff,
                float min_coeff,
                float eps,
                float grad_scale,
                int step,
                int mode,
                int bias_correction,
                float decay)
{
    CHECK_INPUT(p);
    if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
    CHECK_INPUT(m);
    CHECK_INPUT(v);
    CHECK_INPUT(g);
    int64_t num_elem = p.numel();
    AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
    AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
    AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
    AT_ASSERTM(
        p_copy.numel() == num_elem || p_copy.numel() == 0,
        "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");

    // intermediate for weight L2 reduction
    // make sure that the threads per block is at least 512 during the kernel launch otherwise the
    // behavious is unexpected
    at::Tensor w_l2_i = at::empty(
        {512},
        p.options().dtype(p.type().scalarType() == at::ScalarType::Half ? at::ScalarType::Float
                                                                        : p.type().scalarType()));

    // intermediate for update L2 reduction
    // make sure that the threads per block is at least 512 during the kernel launch otherwise the
    // behavious is unexpected
    at::Tensor u_l2_i = at::empty(
        {512},
        p.options().dtype(p.type().scalarType() == at::ScalarType::Half ? at::ScalarType::Float
                                                                        : p.type().scalarType()));

    at::Tensor lamb_coeff_val = at::empty(
        {1},
        p.options().dtype(p.type().scalarType() == at::ScalarType::Half ? at::ScalarType::Float
                                                                        : p.type().scalarType()));

    fused_lamb_cuda(p,
                    p_copy,
                    m,
                    v,
                    g,
                    lr,
                    beta1,
                    beta2,
                    max_coeff,
                    min_coeff,
                    eps,
                    grad_scale,
                    step,
                    mode,
                    bias_correction,
                    decay,
                    w_l2_i,
                    u_l2_i,
                    lamb_coeff_val);

    return lamb_coeff_val;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("lamb", &lamb, "Adam optimized CUDA implementation with LAMB.");
}
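A minimal sketch of exercising the C++ entry point above directly (assuming it is compiled in the same translation unit or declared); the tensor shapes and hyperparameters are arbitrary, and an empty p_copy tensor disables the fp16 master-copy path.

// Illustrative only: a smoke test sketch, not part of this commit.
#include <torch/torch.h>

void lamb_smoke_test()
{
    auto opts = torch::TensorOptions().dtype(torch::kFloat).device(torch::kCUDA);
    auto p = torch::randn({1024}, opts);
    auto g = torch::randn({1024}, opts);
    auto m = torch::zeros({1024}, opts);
    auto v = torch::zeros({1024}, opts);
    auto p_copy = torch::empty({0}, opts);  // empty => no fp16 copy of the parameters is written

    // Returns the computed LAMB trust-ratio coefficient as a 1-element tensor.
    at::Tensor coeff = lamb(p, p_copy, m, v, g,
                            /*lr=*/1e-3f, /*beta1=*/0.9f, /*beta2=*/0.999f,
                            /*max_coeff=*/10.0f, /*min_coeff=*/0.01f,
                            /*eps=*/1e-8f, /*grad_scale=*/1.0f,
                            /*step=*/1, /*mode=*/0, /*bias_correction=*/1,
                            /*decay=*/0.01f);
}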
deepspeed/ops/csrc/lamb/fused_lamb_cuda_kernel.cu  0 → 100644
/* Copyright 2019 The Microsoft DeepSpeed Team */
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/ATen.h"
#include "ATen/TensorUtils.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
//#include "ATen/Type.h"
#include <THC/THCGeneral.h>
#include "ATen/AccumulateType.h"
#include <iostream>
//#include <helper_functions.h>
#include <cooperative_groups.h>
#include <cuda_runtime_api.h>
#include <stdio.h>
namespace cg = cooperative_groups;

// Utility class used to avoid linker errors with extern
// unsized shared memory arrays with templated type
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
template <typename T>
struct SharedMemory {
    // Ensure that we won't compile any un-specialized types
    __device__ inline operator T*()
    {
        extern __device__ void error(void);
        error();
        return NULL;
    }
};

template <>
struct SharedMemory<float> {
    __device__ inline operator float*()
    {
        extern __shared__ float s_float[];
        return s_float;
    }
};

template <>
struct SharedMemory<double> {
    __device__ inline operator double*()
    {
        extern __shared__ double s_double[];
        return s_double;
    }
};
}  // namespace

#include "type_shim.h"

typedef enum {
    ADAM_MODE_0 = 0,  // eps under square root
    ADAM_MODE_1 = 1   // eps outside square root
} adamMode_t;

// s_a and s_b are in shared memory
// g_a and g_b are in shared memory
template <typename T, int blockSize>
__device__ void reduce_block_in_shared_memory(T* s_a, T* s_b, T* g_a, T* g_b)
{
    // Handle to thread block group
    cg::thread_block cta = cg::this_thread_block();
    // perform block reduction in shared memory,
    unsigned int tid = cta.thread_rank();

    T a_sum = s_a[tid];
    T b_sum = s_b[tid];

    cg::sync(cta);

    // do reduction in shared mem
    if ((blockSize >= 512) && (tid < 256)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 256];
        s_b[tid] = b_sum = b_sum + s_b[tid + 256];
    }

    cg::sync(cta);

    if ((blockSize >= 256) && (tid < 128)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 128];
        s_b[tid] = b_sum = b_sum + s_b[tid + 128];
    }

    cg::sync(cta);

    if ((blockSize >= 128) && (tid < 64)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 64];
        s_b[tid] = b_sum = b_sum + s_b[tid + 64];
    }

    cg::sync(cta);

#if (__CUDA_ARCH__ >= 300)
    if (tid < 32) {
        cg::coalesced_group active = cg::coalesced_threads();

        // Fetch final intermediate sum from 2nd warp
        if (blockSize >= 64) {
            a_sum = a_sum + s_a[tid + 32];
            b_sum = b_sum + s_b[tid + 32];
        }

        // Reduce final warp using shuffle
        for (int offset = warpSize / 2; offset > 0; offset /= 2) {
            a_sum += active.shfl_down(a_sum, offset);
            b_sum += active.shfl_down(b_sum, offset);
        }
    }
#else
    if ((blockSize >= 64) && (tid < 32)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 32];
        s_b[tid] = b_sum = b_sum + s_b[tid + 32];
    }

    cg::sync(cta);

    if ((blockSize >= 32) && (tid < 16)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 16];
        s_b[tid] = b_sum = b_sum + s_b[tid + 16];
    }

    cg::sync(cta);

    if ((blockSize >= 16) && (tid < 8)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 8];
        s_b[tid] = b_sum = b_sum + s_b[tid + 8];
    }

    cg::sync(cta);

    if ((blockSize >= 8) && (tid < 4)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 4];
        s_b[tid] = b_sum = b_sum + s_b[tid + 4];
    }

    cg::sync(cta);

    if ((blockSize >= 4) && (tid < 2)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 2];
        s_b[tid] = b_sum = b_sum + s_b[tid + 2];
    }

    cg::sync(cta);

    if ((blockSize >= 2) && (tid < 1)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 1];
        s_b[tid] = b_sum = b_sum + s_b[tid + 1];
    }

    cg::sync(cta);
#endif

    // write result for this block to global mem
    if (tid == 0) {
        g_a[blockIdx.x] = (T)a_sum;
        g_b[blockIdx.x] = (T)b_sum;
    }
}

template <typename T, int blockSize>
__device__ void reduce_two_vectors_in_register(T a, T b, T* g_a, T* g_b)
{
    const int threadIdInBlock = cg::this_thread_block().thread_rank();
    T* s_a = SharedMemory<T>();
    T* s_b = SharedMemory<T>() + cg::this_thread_block().size();
    s_a[threadIdInBlock] = a;
    s_b[threadIdInBlock] = b;
    reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}

template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part1(
    T* __restrict__ p,
    GRAD_T* __restrict__ p_copy,  // For mixed precision training, pass NULL if not needed
    T* __restrict__ m,
    T* __restrict__ v,
    const GRAD_T* __restrict__ g,
    const float b1,
    const float b2,
    const float eps,
    const float grad_scale,
    const float step_size,
    const size_t tsize,
    adamMode_t mode,
    const float decay,
    T* __restrict__ w_l2_i,
    T* __restrict__ u_l2_i)
{
    // Assuming 2D grids and 2D blocks
    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
    const int threadsPerBlock = blockDim.x * blockDim.y;
    const int threadIdInBlock = cg::this_thread_block().thread_rank();
    const int i = (blockId * threadsPerBlock + threadIdInBlock);
    const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;

    T reg_w = 0;
    T reg_u = 0;

    for (int j = i; j < tsize; j += totThreads) {
        T scaled_grad = g[j] / grad_scale;
        T pj = p[j];
        m[j] = b1 * m[j] + (1 - b1) * scaled_grad;
        v[j] = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad;
        float denom;
        if (mode == ADAM_MODE_0)
            denom = sqrtf(v[j] + eps);
        else  // Mode 1
            denom = sqrtf(v[j]) + eps;
        T update = (m[j] / denom) + (decay * p[j]);

        reg_u += update * update;
        reg_w += pj * pj;
    }

    reduce_two_vectors_in_register<T, blockSize>(reg_w, reg_u, w_l2_i, u_l2_i);
}

template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part2(const size_t tsize, T* __restrict__ g_a, T* __restrict__ g_b)
{
    T* s_a = SharedMemory<T>();
    T* s_b = SharedMemory<T>() + cg::this_thread_block().size();

    const int threadIdInBlock = cg::this_thread_block().thread_rank();

    s_a[threadIdInBlock] = g_a[threadIdInBlock];
    s_b[threadIdInBlock] = g_b[threadIdInBlock];

    if (threadIdInBlock >= tsize) {
        s_a[threadIdInBlock] = 0.0;
        s_b[threadIdInBlock] = 0.0;
    }

    reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}

template <typename T, typename GRAD_T>
__global__ void lamb_cuda_kernel_part3(
    T* __restrict__ p,
    GRAD_T* __restrict__ p_copy,  // For mixed precision training, pass NULL if not needed
    T* __restrict__ m,
    T* __restrict__ v,
    const GRAD_T* __restrict__ g,
    const float b1,
    const float b2,
    const float max_coeff,
    const float min_coeff,
    const float eps,
    const float grad_scale,
    const float step_size,
    const size_t tsize,
    adamMode_t mode,
    const float decay,
    T* __restrict__ w_l2_i,
    T* __restrict__ u_l2_i,
    T* __restrict__ lamb_coeff_val)
{
    // Assuming 2D grids and 2D blocks
    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
    const int threadsPerBlock = blockDim.x * blockDim.y;
    const int threadIdInBlock = cg::this_thread_block().thread_rank();
    const int i = (blockId * threadsPerBlock + threadIdInBlock);
    const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;

    T reg_w = sqrtf(w_l2_i[0]);
    T reg_u = sqrtf(u_l2_i[0]);

    float lamb_coeff = 1.0;

    if (reg_w != 0 and reg_u != 0) {
        lamb_coeff = reg_w / reg_u;
        if (lamb_coeff > max_coeff) { lamb_coeff = max_coeff; }
        if (lamb_coeff < min_coeff) { lamb_coeff = min_coeff; }
    }

    if (blockId == 0 and threadIdInBlock == 0) {
        lamb_coeff_val[0] = lamb_coeff;
        // printf("Cuda Lamb Coeff is %.6f \n",lamb_coeff);
    }

    for (int j = i; j < tsize; j += totThreads) {
        T pj = (float)p[j];
        T mj = m[j];
        T vj = v[j];
        float denom;
        if (mode == ADAM_MODE_0)
            denom = sqrtf(vj + eps);
        else  // Mode 1
            denom = sqrtf(vj) + eps;
        T update = (mj / denom) + (decay * pj);

        pj = pj - (step_size * lamb_coeff * update);
        p[j] = pj;
        if (p_copy != NULL) p_copy[j] = (GRAD_T)pj;
    }
}

void fused_lamb_cuda(at::Tensor& p,
                     at::Tensor& p_copy,
                     at::Tensor& m,
                     at::Tensor& v,
                     at::Tensor& g,
                     float lr,
                     float beta1,
                     float beta2,
                     float max_coeff,
                     float min_coeff,
                     float eps,
                     float grad_scale,
                     int step,
                     int mode,
                     int bias_correction,
                     float decay,
                     at::Tensor& w_l2_i,
                     at::Tensor& u_l2_i,
                     at::Tensor& lamb_coeff)
{
    // using namespace at;

    // Get tensor size
    int tsize = p.numel();
    // Determine #threads and #blocks
    const int threadsPerBlock = 512;
    int num_blocks = (tsize + threadsPerBlock - 1) / threadsPerBlock;
    if (num_blocks > 512) num_blocks = 512;

    int smemsize = 0;
    if (p.type().scalarType() == at::ScalarType::Double)
        smemsize = 2 * threadsPerBlock * sizeof(double);
    else
        smemsize = 2 * threadsPerBlock * sizeof(float);

    const dim3 blocks(num_blocks);
    const dim3 threads(threadsPerBlock);

    AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p),
               "parameter tensor is too large to be indexed with int32");
    // Constants
    float step_size = 0;
    if (bias_correction == 1) {
        const float bias_correction1 = 1 - std::pow(beta1, step);
        const float bias_correction2 = 1 - std::pow(beta2, step);
        step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
    } else {
        step_size = lr;
    }
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if (g.type().scalarType() == at::ScalarType::Half) {
        // all other values should be fp32 for half gradients
        AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float,
                   "expected parameter to be of float type");
        // dispatch is done on the gradient type
        using namespace at;  // prevents "toString is undefined" errors
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(
            g.scalar_type(), "lamb_cuda_kernel", ([&] {
                using accscalar_t = at::acc_type<scalar_t, true>;
                lamb_cuda_kernel_part1<accscalar_t, scalar_t, threadsPerBlock>
                    <<<blocks, threadsPerBlock, smemsize, stream>>>(
                        p.data<accscalar_t>(),
                        p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
                        m.data<accscalar_t>(),
                        v.data<accscalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t)mode,
                        decay,
                        w_l2_i.data<accscalar_t>(),
                        u_l2_i.data<accscalar_t>());
                lamb_cuda_kernel_part2<accscalar_t, scalar_t, threadsPerBlock>
                    <<<1, threadsPerBlock, smemsize, stream>>>(
                        num_blocks, w_l2_i.data<accscalar_t>(), u_l2_i.data<accscalar_t>());
                lamb_cuda_kernel_part3<accscalar_t, scalar_t>
                    <<<blocks, threadsPerBlock, smemsize, stream>>>(
                        p.data<accscalar_t>(),
                        p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
                        m.data<accscalar_t>(),
                        v.data<accscalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        max_coeff,
                        min_coeff,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t)mode,
                        decay,
                        w_l2_i.data<accscalar_t>(),
                        u_l2_i.data<accscalar_t>(),
                        lamb_coeff.data<accscalar_t>());
            }));
    } else {
        using namespace at;
        AT_DISPATCH_FLOATING_TYPES(
            g.scalar_type(), "lamb_cuda_kernel", ([&] {
                lamb_cuda_kernel_part1<scalar_t, scalar_t, threadsPerBlock>
                    <<<blocks, threadsPerBlock, smemsize, stream>>>(
                        p.data<scalar_t>(),
                        NULL,  // don't output p_copy for fp32, it's wasted write
                        m.data<scalar_t>(),
                        v.data<scalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t)mode,
                        decay,
                        w_l2_i.data<scalar_t>(),
                        u_l2_i.data<scalar_t>());
                lamb_cuda_kernel_part2<scalar_t, scalar_t, threadsPerBlock>
                    <<<1, threadsPerBlock, smemsize, stream>>>(
                        num_blocks, w_l2_i.data<scalar_t>(), u_l2_i.data<scalar_t>());
                lamb_cuda_kernel_part3<scalar_t, scalar_t>
                    <<<blocks, threadsPerBlock, smemsize, stream>>>(
                        p.data<scalar_t>(),
                        NULL,  // don't output p_copy for fp32, it's wasted write
                        m.data<scalar_t>(),
                        v.data<scalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        max_coeff,
                        min_coeff,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t)mode,
                        decay,
                        w_l2_i.data<scalar_t>(),
                        u_l2_i.data<scalar_t>(),
                        lamb_coeff.data<scalar_t>());
            }));
    }
    THCudaCheck(cudaGetLastError());
}
// template __device__ void reduce_two_vectors_in_register<float,512>(float a, float b, float* g_a,
// float* g_b, cg::grid_group &cgg);
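The host wrapper above folds Adam's bias correction into a single scalar, step_size = lr * sqrt(1 - beta2^t) / (1 - beta1^t). A quick standalone numeric check of that expression (example hyperparameters only):

// Illustrative only: verifies the step-size formula used in fused_lamb_cuda.
#include <cmath>
#include <cstdio>

int main()
{
    const float lr = 1e-3f, beta1 = 0.9f, beta2 = 0.999f;
    for (int step : {1, 10, 1000}) {
        const float bc1 = 1 - std::pow(beta1, step);
        const float bc2 = 1 - std::pow(beta2, step);
        // Approaches lr as both correction terms approach 1 with increasing step.
        std::printf("step %4d: step_size = %.6f\n", step, lr * std::sqrt(bc2) / bc1);
    }
    return 0;
}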
deepspeed/ops/csrc/lamb/hip/fused_lamb_hip.cpp  0 → 100644
/* Copyright 2019 The Microsoft DeepSpeed Team */
#include <torch/extension.h>
// CUDA forward declaration
void fused_lamb_cuda(at::Tensor& p,
                     at::Tensor& p_copy,
                     at::Tensor& m,
                     at::Tensor& v,
                     at::Tensor& g,
                     float lr,
                     float beta1,
                     float beta2,
                     float max_coeff,
                     float min_coeff,
                     float eps,
                     float grad_scale,
                     int step,
                     int mode,
                     int bias_correction,
                     float decay,
                     at::Tensor& w_l2_i,
                     at::Tensor& u_l2_i,
                     at::Tensor& lamb_coeff_val);

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x)

// C++ interface
at::Tensor lamb(at::Tensor& p,
                at::Tensor& p_copy,
                at::Tensor& m,
                at::Tensor& v,
                at::Tensor& g,
                float lr,
                float beta1,
                float beta2,
                float max_coeff,
                float min_coeff,
                float eps,
                float grad_scale,
                int step,
                int mode,
                int bias_correction,
                float decay)
{
    CHECK_INPUT(p);
    if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
    CHECK_INPUT(m);
    CHECK_INPUT(v);
    CHECK_INPUT(g);
    int64_t num_elem = p.numel();
    AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
    AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
    AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
    AT_ASSERTM(
        p_copy.numel() == num_elem || p_copy.numel() == 0,
        "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");

    // intermediate for weight L2 reduction
    // make sure that the threads per block is at least 512 during the kernel launch otherwise the
    // behavious is unexpected
    at::Tensor w_l2_i = at::empty(
        {512},
        p.options().dtype(p.type().scalarType() == at::ScalarType::Half ? at::ScalarType::Float
                                                                        : p.type().scalarType()));

    // intermediate for update L2 reduction
    // make sure that the threads per block is at least 512 during the kernel launch otherwise the
    // behavious is unexpected
    at::Tensor u_l2_i = at::empty(
        {512},
        p.options().dtype(p.type().scalarType() == at::ScalarType::Half ? at::ScalarType::Float
                                                                        : p.type().scalarType()));

    at::Tensor lamb_coeff_val = at::empty(
        {1},
        p.options().dtype(p.type().scalarType() == at::ScalarType::Half ? at::ScalarType::Float
                                                                        : p.type().scalarType()));

    fused_lamb_cuda(p,
                    p_copy,
                    m,
                    v,
                    g,
                    lr,
                    beta1,
                    beta2,
                    max_coeff,
                    min_coeff,
                    eps,
                    grad_scale,
                    step,
                    mode,
                    bias_correction,
                    decay,
                    w_l2_i,
                    u_l2_i,
                    lamb_coeff_val);

    return lamb_coeff_val;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("lamb", &lamb, "Adam optimized CUDA implementation with LAMB.");
}
deepspeed/ops/csrc/lamb/hip/fused_lamb_hip_kernel.hip  0 → 100644
/* Copyright 2019 The Microsoft DeepSpeed Team */
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/ATen.h"
#include "ATen/TensorUtils.h"
#include "ATen/hip/HIPContext.h"
#include "ATen/hip/detail/IndexUtils.cuh"
//#include "ATen/Type.h"
#include <THH/THHGeneral.h>
#include "ATen/AccumulateType.h"
#include <iostream>
//#include <helper_functions.h>
#include <cooperative_groups.h>
#include <hip/hip_runtime_api.h>
#include <stdio.h>
namespace cg = cooperative_groups;
// Utility class used to avoid linker errors with extern
// unsized shared memory arrays with templated type
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
template <typename T>
struct SharedMemory {
// Ensure that we won't compile any un-specialized types
__device__ inline operator T*()
{
extern __device__ void error(void);
error();
return NULL;
}
};
template <>
struct SharedMemory<float> {
__device__ inline operator float*()
{
HIP_DYNAMIC_SHARED( float, s_float)
return s_float;
}
};
template <>
struct SharedMemory<double> {
__device__ inline operator double*()
{
HIP_DYNAMIC_SHARED( double, s_double)
return s_double;
}
};
} // namespace
#include "type_shim.h"
typedef enum {
ADAM_MODE_0 = 0, // eps under square root
ADAM_MODE_1 = 1 // eps outside square root
} adamMode_t;
// s_a and s_b are in shared memory
// g_a and g_b are in shared memory
template <typename T, int blockSize>
__device__ void reduce_block_in_shared_memory(T* s_a, T* s_b, T* g_a, T* g_b)
{
// Handle to thread block group
cg::thread_block cta = cg::this_thread_block();
// perform block reduction in shared memory,
unsigned int tid = cta.thread_rank();
T a_sum = s_a[tid];
T b_sum = s_b[tid];
cg::sync(cta);
// do reduction in shared mem
if ((blockSize >= 512) && (tid < 256)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 256];
s_b[tid] = b_sum = b_sum + s_b[tid + 256];
}
cg::sync(cta);
if ((blockSize >= 256) && (tid < 128)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 128];
s_b[tid] = b_sum = b_sum + s_b[tid + 128];
}
cg::sync(cta);
if ((blockSize >= 128) && (tid < 64)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 64];
s_b[tid] = b_sum = b_sum + s_b[tid + 64];
}
cg::sync(cta);
#if (__CUDA_ARCH__ >= 300)
if (tid < 32) {
cg::coalesced_group active = cg::coalesced_threads();
// Fetch final intermediate sum from 2nd warp
if (blockSize >= 64) {
a_sum = a_sum + s_a[tid + 32];
b_sum = b_sum + s_b[tid + 32];
}
// Reduce final warp using shuffle
for (int offset = warpSize / 2; offset > 0; offset /= 2) {
a_sum += active.shfl_down(a_sum, offset);
b_sum += active.shfl_down(b_sum, offset);
}
}
#else
if ((blockSize >= 64) && (tid < 32)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 32];
s_b[tid] = b_sum = b_sum + s_b[tid + 32];
}
cg::sync(cta);
if ((blockSize >= 32) && (tid < 16)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 16];
s_b[tid] = b_sum = b_sum + s_b[tid + 16];
}
cg::sync(cta);
if ((blockSize >= 16) && (tid < 8)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 8];
s_b[tid] = b_sum = b_sum + s_b[tid + 8];
}
cg::sync(cta);
if ((blockSize >= 8) && (tid < 4)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 4];
s_b[tid] = b_sum = b_sum + s_b[tid + 4];
}
cg::sync(cta);
if ((blockSize >= 4) && (tid < 2)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 2];
s_b[tid] = b_sum = b_sum + s_b[tid + 2];
}
cg::sync(cta);
if ((blockSize >= 2) && (tid < 1)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 1];
s_b[tid] = b_sum = b_sum + s_b[tid + 1];
}
cg::sync(cta);
#endif
// write result for this block to global mem
if (tid == 0) {
g_a[blockIdx.x] = (T)a_sum;
g_b[blockIdx.x] = (T)b_sum;
}
}
template <typename T, int blockSize>
__device__ void reduce_two_vectors_in_register(T a, T b, T* g_a, T* g_b)
{
const int threadIdInBlock = cg::this_thread_block().thread_rank();
T* s_a = SharedMemory<T>();
T* s_b = SharedMemory<T>() + cg::this_thread_block().size();
s_a[threadIdInBlock] = a;
s_b[threadIdInBlock] = b;
reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part1(
T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T* __restrict__ g,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay,
T* __restrict__ w_l2_i,
T* __restrict__ u_l2_i)
{
// Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = cg::this_thread_block().thread_rank();
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;
T reg_w = 0;
T reg_u = 0;
for (int j = i; j < tsize; j += totThreads) {
T scaled_grad = g[j] / grad_scale;
T pj = p[j];
m[j] = b1 * m[j] + (1 - b1) * scaled_grad;
v[j] = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(v[j] + eps);
else // Mode 1
denom = sqrtf(v[j]) + eps;
T update = (m[j] / denom) + (decay * p[j]);
reg_u += update * update;
reg_w += pj * pj;
}
reduce_two_vectors_in_register<T, blockSize>(reg_w, reg_u, w_l2_i, u_l2_i);
}
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part2(const size_t tsize, T* __restrict__ g_a, T* __restrict__ g_b)
{
T* s_a = SharedMemory<T>();
T* s_b = SharedMemory<T>() + cg::this_thread_block().size();
const int threadIdInBlock = cg::this_thread_block().thread_rank();
s_a[threadIdInBlock] = g_a[threadIdInBlock];
s_b[threadIdInBlock] = g_b[threadIdInBlock];
if (threadIdInBlock >= tsize) {
s_a[threadIdInBlock] = 0.0;
s_b[threadIdInBlock] = 0.0;
}
reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}
template <typename T, typename GRAD_T>
__global__ void lamb_cuda_kernel_part3(
T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T* __restrict__ g,
const float b1,
const float b2,
const float max_coeff,
const float min_coeff,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay,
T* __restrict__ w_l2_i,
T* __restrict__ u_l2_i,
T* __restrict__ lamb_coeff_val)
{
// Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = cg::this_thread_block().thread_rank();
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;
T reg_w = sqrtf(w_l2_i[0]);
T reg_u = sqrtf(u_l2_i[0]);
float lamb_coeff = 1.0;
if (reg_w != 0 and reg_u != 0) {
lamb_coeff = reg_w / reg_u;
if (lamb_coeff > max_coeff) { lamb_coeff = max_coeff; }
if (lamb_coeff < min_coeff) { lamb_coeff = min_coeff; }
}
if (blockId == 0 and threadIdInBlock == 0) {
lamb_coeff_val[0] = lamb_coeff;
// printf("Cuda Lamb Coeff is %.6f \n",lamb_coeff);
}
for (int j = i; j < tsize; j += totThreads) {
T pj = (float)p[j];
T mj = m[j];
T vj = v[j];
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(vj + eps);
else // Mode 1
denom = sqrtf(vj) + eps;
T update = (mj / denom) + (decay * pj);
pj = pj - (step_size * lamb_coeff * update);
p[j] = pj;
if (p_copy != NULL) p_copy[j] = (GRAD_T)pj;
}
}
void fused_lamb_cuda(at::Tensor& p,
at::Tensor& p_copy,
at::Tensor& m,
at::Tensor& v,
at::Tensor& g,
float lr,
float beta1,
float beta2,
float max_coeff,
float min_coeff,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay,
at::Tensor& w_l2_i,
at::Tensor& u_l2_i,
at::Tensor& lamb_coeff)
{
// using namespace at;
// Get tensor size
int tsize = p.numel();
// Determine #threads and #blocks
const int threadsPerBlock = 512;
int num_blocks = (tsize + threadsPerBlock - 1) / threadsPerBlock;
if (num_blocks > 512) num_blocks = 512;
int smemsize = 0;
if (p.type().scalarType() == at::ScalarType::Double)
smemsize = 2 * threadsPerBlock * sizeof(double);
else
smemsize = 2 * threadsPerBlock * sizeof(float);
const dim3 blocks(num_blocks);
const dim3 threads(threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p),
"parameter tensor is too large to be indexed with int32");
// Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - ::pow(beta1, step);
const float bias_correction2 = 1 - ::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
} else {
step_size = lr;
}
hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
if (g.type().scalarType() == at::ScalarType::Half) {
// all other values should be fp32 for half gradients
AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float,
"expected parameter to be of float type");
// dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
g.scalar_type(), "lamb_cuda_kernel", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
hipLaunchKernelGGL(( lamb_cuda_kernel_part1<accscalar_t, scalar_t, threadsPerBlock>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<accscalar_t>(),
p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
m.data<accscalar_t>(),
v.data<accscalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<accscalar_t>(),
u_l2_i.data<accscalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part2<accscalar_t, scalar_t, threadsPerBlock>)
, dim3(1), dim3(threadsPerBlock), smemsize, stream,
num_blocks, w_l2_i.data<accscalar_t>(), u_l2_i.data<accscalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part3<accscalar_t, scalar_t>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<accscalar_t>(),
p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
m.data<accscalar_t>(),
v.data<accscalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
max_coeff,
min_coeff,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<accscalar_t>(),
u_l2_i.data<accscalar_t>(),
lamb_coeff.data<accscalar_t>());
}));
} else {
using namespace at;
AT_DISPATCH_FLOATING_TYPES(
g.scalar_type(), "lamb_cuda_kernel", ([&] {
hipLaunchKernelGGL(( lamb_cuda_kernel_part1<scalar_t, scalar_t, threadsPerBlock>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<scalar_t>(),
NULL, // don't output p_copy for fp32, it's wasted write
m.data<scalar_t>(),
v.data<scalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<scalar_t>(),
u_l2_i.data<scalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part2<scalar_t, scalar_t, threadsPerBlock>)
, dim3(1), dim3(threadsPerBlock), smemsize, stream,
num_blocks, w_l2_i.data<scalar_t>(), u_l2_i.data<scalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part3<scalar_t, scalar_t>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<scalar_t>(),
NULL, // don't output p_copy for fp32, it's wasted write
m.data<scalar_t>(),
v.data<scalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
max_coeff,
min_coeff,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<scalar_t>(),
u_l2_i.data<scalar_t>(),
lamb_coeff.data<scalar_t>());
}));
}
THCudaCheck(hipGetLastError());
}
// template __device__ void reduce_two_vectors_in_register<float,512>(float a, float b, float* g_a,
// float* g_b, cg::grid_group &cgg);
deepspeed/ops/csrc/sparse_attention/hip/utils.cpp  0 → 100644
// DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
// https://github.com/ptillet/torch-blocksparse/blob/master/csrc/utils.cpp
#include <torch/extension.h>
#include <string>
#include <tuple>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif
typedef std::vector<std::tuple<int, torch::Tensor>> ret_t;

void segment_blocks(torch::Tensor layout, torch::Tensor idx, torch::Tensor scratch,
                    int max_width, ret_t& ret)
{
    size_t H = layout.size(0);
    size_t M = layout.size(1);
    size_t N = layout.size(2);

    torch::Tensor tmp = torch::zeros_like(layout);

    auto _tmp = tmp.accessor<int, 3>();
    auto _layout = layout.accessor<int, 3>();
    auto _idx = idx.accessor<int, 3>();
    auto _scratch = scratch.accessor<int, 3>();
    std::vector<int> current(H, 0);

#ifdef _OPENMP
#pragma omp parallel for
#endif
    for (size_t h = 0; h < H; h++) {
        // surrounding indices
        std::vector<int> ii_left(max_width, -1);
        std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));

        for (size_t m = 0; m < M; m++) {
            for (size_t n = 0; n < N; n++) {
                int v = _layout[h][m][n];
                if (v == 0) continue;
                int n_left = ii_left[max_width - 1];
                int m_top = ii_top[max_width - 1][n];
                int top = (m_top >= 0) ? _tmp[h][m_top][n] : 0;
                int left = (n_left >= 0) ? _tmp[h][m][n_left] : 0;
                int topleft = (m_top >= 0 && n_left >= 0) ? _tmp[h][m_top][n_left] : 0;
                int width = std::min(left, std::min(top, topleft)) + 1;

                // reset width if blocks cannot be
                // packed together (i.e., there's a 1 "in the middle")
                for (int nn = n_left + 1; nn < n; nn++)
                    if (ii_top[max_width - 1][nn] > ii_top[max_width - 1][n]) width = 1;
                _tmp[h][m][n] = width;

                // update n_left ring buffer
                for (int k = 0; k < max_width - 1; k++) ii_left[k] = ii_left[k + 1];
                ii_left[max_width - 1] = n;

                // update ii_top ring buffer
                for (int k = 0; k < max_width - 1; k++) ii_top[k][n] = ii_top[k + 1][n];
                ii_top[max_width - 1][n] = m;

                // block is too small -- skip
                if (width != max_width) continue;

                // retained blocks are set to zeros
                for (size_t km = 0; km < max_width; km++)
                    for (size_t kn = 0; kn < max_width; kn++) {
                        int mm = ii_top[km][n];
                        int nn = ii_left[kn];
                        if (mm < 0 || nn < 0) continue;
                        _layout[h][mm][nn] = 0;
                        _tmp[h][mm][nn] = 0;
                        _scratch[h][current[h]][0] = (int)h;
                        _scratch[h][current[h]][1] = (int)mm;
                        _scratch[h][current[h]][2] = (int)nn;
                        _scratch[h][current[h]][3] = _idx[h][mm][nn];
                        current[h]++;
                    }
            }
        }
    }

    std::vector<torch::Tensor> to_cat;
    for (size_t h = 0; h < H; h++)
        if (current[h] > 0) to_cat.push_back(scratch[h].slice(0, 0, current[h]));
    if (!to_cat.empty()) ret.push_back({max_width, torch::cat(to_cat)});
}

ret_t sdd_segment(torch::Tensor layout, int start_width)
{
    ret_t ret;

    // block index
    torch::Tensor idx = torch::zeros_like(layout);
    int current = 0;
    size_t H = layout.size(0);
    size_t M = layout.size(1);
    size_t N = layout.size(2);
    auto _layout = layout.accessor<int, 3>();
    auto _idx = idx.accessor<int, 3>();
    for (size_t h = 0; h < H; h++)
        for (size_t m = 0; m < M; m++)
            for (size_t n = 0; n < N; n++) {
                if (_layout[h][m][n] == 0) continue;
                _idx[h][m][n] = current++;
            }

    // scratch memory
    torch::Tensor scratch = torch::empty({H, layout.sum().item<int>(), 4}, layout.dtype());
    for (int max_width = start_width; max_width > 0; max_width /= 2)
        segment_blocks(layout, idx, scratch, max_width, ret);
    return ret;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("sdd_segment", &sdd_segment, "SDD segmentation handler");
}
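For orientation, the sketch below shows how `sdd_segment` could be exercised from C++ when this translation unit is linked against libtorch. It is a minimal illustration, not part of the commit: the 1x4x4 all-ones `layout`, the `start_width` of 4, and the standalone `main` are assumptions for demonstration (in practice the function is called from Python through the pybind11 module above).

#include <iostream>
#include <torch/torch.h>

// Illustrative driver for the segmentation routine above.
// sdd_segment greedily packs non-zero blocks of the sparsity layout into
// square super-blocks of width start_width, start_width/2, ..., 1.
int main()
{
    // One attention head with a dense 4x4 block layout (all blocks present).
    torch::Tensor layout = torch::ones({1, 4, 4}, torch::kInt32);
    ret_t segments = sdd_segment(layout, /*start_width=*/4);
    for (auto& seg : segments) {
        std::cout << "width " << std::get<0>(seg) << ": " << std::get<1>(seg).size(0)
                  << " blocks" << std::endl;
    }
    return 0;
}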
deepspeed/ops/csrc/sparse_attention/utils.cpp
0 → 100644
View file @
eadbbe09
// DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
// https://github.com/ptillet/torch-blocksparse/blob/master/csrc/utils.cpp
#include <torch/extension.h>
#include <string>
#include <tuple>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif
typedef std::vector<std::tuple<int, torch::Tensor>> ret_t;

void segment_blocks(torch::Tensor layout, torch::Tensor idx, torch::Tensor scratch,
                    int max_width, ret_t& ret)
{
    size_t H = layout.size(0);
    size_t M = layout.size(1);
    size_t N = layout.size(2);

    torch::Tensor tmp = torch::zeros_like(layout);

    auto _tmp = tmp.accessor<int, 3>();
    auto _layout = layout.accessor<int, 3>();
    auto _idx = idx.accessor<int, 3>();
    auto _scratch = scratch.accessor<int, 3>();
    std::vector<int> current(H, 0);

#ifdef _OPENMP
#pragma omp parallel for
#endif
    for (size_t h = 0; h < H; h++) {
        // surrounding indices
        std::vector<int> ii_left(max_width, -1);
        std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));

        for (size_t m = 0; m < M; m++) {
            for (size_t n = 0; n < N; n++) {
                int v = _layout[h][m][n];
                if (v == 0) continue;
                int n_left = ii_left[max_width - 1];
                int m_top = ii_top[max_width - 1][n];
                int top = (m_top >= 0) ? _tmp[h][m_top][n] : 0;
                int left = (n_left >= 0) ? _tmp[h][m][n_left] : 0;
                int topleft = (m_top >= 0 && n_left >= 0) ? _tmp[h][m_top][n_left] : 0;
                int width = std::min(left, std::min(top, topleft)) + 1;

                // reset width if blocks cannot be
                // packed together (i.e., there's a 1 "in the middle")
                for (int nn = n_left + 1; nn < n; nn++)
                    if (ii_top[max_width - 1][nn] > ii_top[max_width - 1][n]) width = 1;
                _tmp[h][m][n] = width;

                // update n_left ring buffer
                for (int k = 0; k < max_width - 1; k++) ii_left[k] = ii_left[k + 1];
                ii_left[max_width - 1] = n;

                // update ii_top ring buffer
                for (int k = 0; k < max_width - 1; k++) ii_top[k][n] = ii_top[k + 1][n];
                ii_top[max_width - 1][n] = m;

                // block is too small -- skip
                if (width != max_width) continue;

                // retained blocks are set to zeros
                for (size_t km = 0; km < max_width; km++)
                    for (size_t kn = 0; kn < max_width; kn++) {
                        int mm = ii_top[km][n];
                        int nn = ii_left[kn];
                        if (mm < 0 || nn < 0) continue;
                        _layout[h][mm][nn] = 0;
                        _tmp[h][mm][nn] = 0;
                        _scratch[h][current[h]][0] = (int)h;
                        _scratch[h][current[h]][1] = (int)mm;
                        _scratch[h][current[h]][2] = (int)nn;
                        _scratch[h][current[h]][3] = _idx[h][mm][nn];
                        current[h]++;
                    }
            }
        }
    }

    std::vector<torch::Tensor> to_cat;
    for (size_t h = 0; h < H; h++)
        if (current[h] > 0) to_cat.push_back(scratch[h].slice(0, 0, current[h]));
    if (!to_cat.empty()) ret.push_back({max_width, torch::cat(to_cat)});
}

ret_t sdd_segment(torch::Tensor layout, int start_width)
{
    ret_t ret;

    // block index
    torch::Tensor idx = torch::zeros_like(layout);
    int current = 0;
    size_t H = layout.size(0);
    size_t M = layout.size(1);
    size_t N = layout.size(2);
    auto _layout = layout.accessor<int, 3>();
    auto _idx = idx.accessor<int, 3>();
    for (size_t h = 0; h < H; h++)
        for (size_t m = 0; m < M; m++)
            for (size_t n = 0; n < N; n++) {
                if (_layout[h][m][n] == 0) continue;
                _idx[h][m][n] = current++;
            }

    // scratch memory
    torch::Tensor scratch = torch::empty({H, layout.sum().item<int>(), 4}, layout.dtype());
    for (int max_width = start_width; max_width > 0; max_width /= 2)
        segment_blocks(layout, idx, scratch, max_width, ret);
    return ret;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("sdd_segment", &sdd_segment, "SDD segmentation handler");
}
deepspeed/ops/csrc/transformer/cublas_wrappers.cu
0 → 100644
View file @
eadbbe09
#include "cublas_wrappers.h"
int cublas_gemm_ex(cublasHandle_t handle,
                   cublasOperation_t transa,
                   cublasOperation_t transb,
                   int m, int n, int k,
                   const float* alpha, const float* beta,
                   const float* A, const float* B, float* C,
                   cublasGemmAlgo_t algo)
{
    cublasStatus_t status = cublasGemmEx(handle, transa, transb, m, n, k,
                                         (const void*)alpha,
                                         (const void*)A, CUDA_R_32F,
                                         (transa == CUBLAS_OP_N) ? m : k,
                                         (const void*)B, CUDA_R_32F,
                                         (transb == CUBLAS_OP_N) ? k : n,
                                         (const void*)beta,
                                         C, CUDA_R_32F, m,
                                         CUDA_R_32F, algo);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m, n, k, (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}

int cublas_gemm_ex(cublasHandle_t handle,
                   cublasOperation_t transa,
                   cublasOperation_t transb,
                   int m, int n, int k,
                   const float* alpha, const float* beta,
                   const __half* A, const __half* B, __half* C,
                   cublasGemmAlgo_t algo)
{
    cublasStatus_t status = cublasGemmEx(handle, transa, transb, m, n, k,
                                         (const void*)alpha,
                                         (const void*)A, CUDA_R_16F,
                                         (transa == CUBLAS_OP_N) ? m : k,
                                         (const void*)B, CUDA_R_16F,
                                         (transb == CUBLAS_OP_N) ? k : n,
                                         (const void*)beta,
                                         (void*)C, CUDA_R_16F, m,
                                         CUDA_R_32F, algo);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m, n, k, (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}

int cublas_strided_batched_gemm(cublasHandle_t handle,
                                int m, int n, int k,
                                const float* alpha, const float* beta,
                                const float* A, const float* B, float* C,
                                cublasOperation_t op_A, cublasOperation_t op_B,
                                int stride_A, int stride_B, int stride_C,
                                int batch,
                                cublasGemmAlgo_t algo)
{
    cublasStatus_t status =
        cublasGemmStridedBatchedEx(handle, op_A, op_B, m, n, k,
                                   alpha,
                                   A, CUDA_R_32F, (op_A == CUBLAS_OP_N) ? m : k, stride_A,
                                   B, CUDA_R_32F, (op_B == CUBLAS_OP_N) ? k : n, stride_B,
                                   beta,
                                   C, CUDA_R_32F, m, stride_C,
                                   batch,
                                   CUDA_R_32F, algo);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr,
                "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
                batch, m, n, k, (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}

int cublas_strided_batched_gemm(cublasHandle_t handle,
                                int m, int n, int k,
                                const float* alpha, const float* beta,
                                const __half* A, const __half* B, __half* C,
                                cublasOperation_t op_A, cublasOperation_t op_B,
                                int stride_A, int stride_B, int stride_C,
                                int batch,
                                cublasGemmAlgo_t algo)
{
    cublasStatus_t status =
        cublasGemmStridedBatchedEx(handle, op_A, op_B, m, n, k,
                                   alpha,
                                   A, CUDA_R_16F, (op_A == CUBLAS_OP_N) ? m : k, stride_A,
                                   B, CUDA_R_16F, (op_B == CUBLAS_OP_N) ? k : n, stride_B,
                                   beta,
                                   C, CUDA_R_16F, m, stride_C,
                                   batch,
                                   CUDA_R_32F, algo);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m, n, k, (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}
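As a quick reference, the sketch below calls the fp32 wrapper above for a column-major C = A * B. It is an illustrative assumption, not part of the commit: the `example_gemm` helper, its matrix shapes, and the locally created handle are made up for demonstration; real callers pass the handle managed by the DeepSpeed context.

#include <cublas_v2.h>
#include <cuda_runtime.h>

// Illustrative call of the fp32 wrapper above: C(m x n) = A(m x k) * B(k x n),
// all matrices column-major and already resident on the device.
void example_gemm(const float* d_A, const float* d_B, float* d_C, int m, int n, int k)
{
    cublasHandle_t handle;
    cublasCreate(&handle);
    const float alpha = 1.0f, beta = 0.0f;
    cublas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, &beta,
                   d_A, d_B, d_C, CUBLAS_GEMM_DEFAULT);
    cublasDestroy(handle);
}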
deepspeed/ops/csrc/transformer/dropout_kernels.cu
0 → 100755
View file @
eadbbe09
#include "custom_cuda_layers.h"
const int unroll_factor = 4;

__global__ void dropout_kernel(const int N, const float ratio,
                               float* out, const float* Xdata, uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) {
        float4 rand = curand_uniform4(&state);
        uint8_t m[unroll_factor];

        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        int i = j * unroll_factor;

        mask[i] = (uint8_t)m[0];
        mask[i + 1] = (uint8_t)m[1];
        mask[i + 2] = (uint8_t)m[2];
        mask[i + 3] = (uint8_t)m[3];

        out[i] = Xdata[i] * scale * m[0];
        out[i + 1] = Xdata[i + 1] * scale * m[1];
        out[i + 2] = Xdata[i + 2] * scale * m[2];
        out[i + 3] = Xdata[i + 3] * scale * m[3];
    }
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            out[i] = Xdata[i] * scale * m;
            mask[i] = m;
        }
    }
}

__global__ void dropout_kernel(const int N, const float ratio,
                               __half* out, const __half* Xdata, uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

#ifdef __STOCHASTIC_MODE__

    const __half2 h_scale = __float2half2_rn(scale);
    const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
    float2* out_cast = reinterpret_cast<float2*>(out);
    uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);

    uint32_t m_32;
    uint8_t* m = reinterpret_cast<uint8_t*>(&m_32);

    float2 result_f;
    __half2* result_h = reinterpret_cast<__half2*>(&result_f);
    __half2 mask_h[2];
    float2 mask_f[2];

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) {
        float2 x_f = x_cast[j];
        __half2* x_h = reinterpret_cast<__half2*>(&x_f);

        float4 rand = curand_uniform4(&state);

        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        float* mask_f_data = &mask_f[0].x;
#pragma unroll
        for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);

        mask_h[0] = __float22half2_rn(mask_f[0]);
        mask_h[1] = __float22half2_rn(mask_f[1]);

        result_h[0] = x_h[0] * h_scale * mask_h[0];
        result_h[1] = x_h[1] * h_scale * mask_h[1];

        out_cast[j] = result_f;
        mask_cast[j] = m_32;
    }

#else

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) {
        int i = j * unroll_factor;

        const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
        float2 vals_half_f[2];
        vals_half_f[0] = __half22float2(vals_half[0]);
        vals_half_f[1] = __half22float2(vals_half[1]);

        uint8_t m[unroll_factor];
        float4 rand = curand_uniform4(&state);
        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
        out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
        out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
        out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);

        mask[i] = m[0];
        mask[i + 1] = m[1];
        mask[i + 2] = m[2];
        mask[i + 3] = m[3];
    }

#endif
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            out[i] = __float2half((float)Xdata[i] * scale * m);
            mask[i] = m;
        }
    }
}

__global__ void dropout_kernel_bwd(const int N, const float ratio,
                                   const float* Xdata, float* out, uint8_t* mask,
                                   std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) {
        int i = j * unroll_factor;

        out[i] = mask[i] ? Xdata[i] * scale : 0.0;
        out[i + 1] = mask[i + 1] ? Xdata[i + 1] * scale : 0.0;
        out[i + 2] = mask[i + 2] ? Xdata[i + 2] * scale : 0.0;
        out[i + 3] = mask[i + 3] ? Xdata[i + 3] * scale : 0.0;
    }
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        for (int i = high_index; i < N; i++) { out[i] = mask[i] ? Xdata[i] * scale : 0.0; }
    }
}

__global__ void dropout_kernel_bwd(const int N, const float ratio,
                                   const __half* Xdata, __half* out, uint8_t* mask,
                                   std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);

#ifdef __STOCHASTIC_MODE__

    const __half2 h_scale = __float2half2_rn(scale);

    const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
    float2* out_cast = reinterpret_cast<float2*>(out);
    uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) {
        float2 x_f = x_cast[j];
        __half2* x_h = reinterpret_cast<__half2*>(&x_f);

        uint32_t m_32 = mask_cast[j];
        uint8_t* m = (uint8_t*)&m_32;

        __half2 mask_h[2];
        float2 mask_f[2];

        float* mask_f_data = &mask_f[0].x;
#pragma unroll
        for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);

#pragma unroll
        for (int i = 0; i < 2; i++) mask_h[i] = __float22half2_rn(mask_f[i]);

        float2 result_f;
        __half2* result_h = reinterpret_cast<__half2*>(&result_f);

        result_h[0] = x_h[0] * h_scale * mask_h[0];
        result_h[1] = x_h[1] * h_scale * mask_h[1];

        out_cast[j] = result_f;
    }

#else

    const __half h_scale = __float2half(scale);
    const __half h_zero = __float2half(0.0);

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) {
        int i = j * unroll_factor;

        const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);

        uint8_t* m = mask + i;

        float2 vals_half_f[2];
        vals_half_f[0] = __half22float2(vals_half[0]);
        vals_half_f[1] = __half22float2(vals_half[1]);

        out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
        out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
        out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
        out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);
    }

#endif
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        for (int i = high_index; i < N; i++) {
            out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
        }
    }
}

template <typename T>
void launch_dropout(T* out, const T* vals, uint8_t* mask,
                    int total_count, int dim, float ratio,
                    cudaStream_t stream, bool bwd)
{
    assert(unroll_factor == 4);

    dim3 grid_dim = DS_GET_BLOCKS(total_count / unroll_factor);
    dim3 block_dim = DS_CUDA_NUM_THREADS;

    if (dim > 512) {
        block_dim.x >>= 1;
        grid_dim.x <<= 1;
    }
    uint64_t inc = total_count / grid_dim.x / block_dim.x;
    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
    if (bwd)
        dropout_kernel_bwd<<<grid_dim, block_dim, 0, stream>>>(
            total_count, ratio, vals, out, mask, seed);
    else
        dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
            total_count, ratio, out, vals, mask, seed);
}

template void launch_dropout(float* out, const float* vals, uint8_t* mask,
                             int total_count, int dim, float ratio,
                             cudaStream_t stream, bool);
template void launch_dropout(__half* out, const __half* vals, uint8_t* mask,
                             int total_count, int dim, float ratio,
                             cudaStream_t stream, bool);

__global__ void dropout_grad_kernel(const int N, const float scale, float* Xdata, uint8_t* mask)
{
    CUDA_1D_KERNEL_LOOP(i, N) { Xdata[i] *= scale * mask[i]; }
}

__global__ void dropout_grad_kernel(const int N, const float scale, __half* Xdata, uint8_t* mask)
{
    const __half2 h_scale = __float2half2_rn(scale);
    float2* x_cast = reinterpret_cast<float2*>(Xdata);
    uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) {
        float2 x_data = x_cast[j];
        uint32_t m_32 = mask_cast[j];
        uint8_t* m = (uint8_t*)&m_32;

        float2 result_f;
        __half2* result_h = reinterpret_cast<__half2*>(&result_f);

#ifdef __STOCHASTIC_MODE__

        __half2* x_data_h = reinterpret_cast<__half2*>(&x_data);
        __half2 mask_h[2];
        float2 mask_f[2];

        float* mask_f_data = &mask_f[0].x;
#pragma unroll
        for (int i = 0; i < unroll_factor; i++) *(mask_f_data++) = (float)(m[i]);

        mask_h[0] = __float22half2_rn(mask_f[0]);
        mask_h[1] = __float22half2_rn(mask_f[1]);

        result_h[0] = x_data_h[0] * h_scale * mask_h[0];
        result_h[1] = x_data_h[1] * h_scale * mask_h[1];

#else

        __half* x_data_h = reinterpret_cast<__half*>(&x_data);
        float2 result[2];

        result[0].x = (float)x_data_h[0] * scale * m[0];
        result[0].y = (float)x_data_h[1] * scale * m[1];
        result[1].x = (float)x_data_h[2] * scale * m[2];
        result[1].y = (float)x_data_h[3] * scale * m[3];

        result_h[0] = __float22half2_rn(result[0]);
        result_h[1] = __float22half2_rn(result[1]);

#endif
        x_cast[j] = result_f;
    }
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        for (int i = high_index; i < N; i++) {
            Xdata[i] = __float2half((float)Xdata[i] * scale * mask[i]);
        }
    }
}

template <typename T>
void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, cudaStream_t stream)
{
    assert(unroll_factor == 4);

    const float scale = 1. / (1. - ratio);
    dropout_grad_kernel<<<DS_GET_BLOCKS(total_count / unroll_factor),
                          DS_CUDA_NUM_THREADS, 0, stream>>>(total_count, scale, vals, mask);
}

template void launch_dropout_grad(float* vals, uint8_t* mask, int total_count,
                                  float ratio, cudaStream_t stream);
template void launch_dropout_grad(__half* vals, uint8_t* mask, int total_count,
                                  float ratio, cudaStream_t stream);

__global__ void dropout_grad_kernel(const int N, const float scale,
                                    const float* Xdata, float* out, uint8_t* mask)
{
    CUDA_1D_KERNEL_LOOP(i, N) { out[i] = Xdata[i] * scale * mask[i]; }
}

__global__ void dropout_grad_kernel(const int N, const float scale,
                                    const __half* Xdata, __half* out, uint8_t* mask)
{
    const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
    float2* out_cast = reinterpret_cast<float2*>(out);
    const uint32_t* mask_cast = reinterpret_cast<const uint32_t*>(mask);

    float2 result_f;
    __half2* result_h = reinterpret_cast<__half2*>(&result_f);

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) {
        float2 x_data = x_cast[j];
        uint32_t m_32 = mask_cast[j];
        uint8_t* m = (uint8_t*)&m_32;

        __half* x_data_h = reinterpret_cast<__half*>(&x_data);
        float2 result[2];

        result[0].x = (float)x_data_h[0] * scale * m[0];
        result[0].y = (float)x_data_h[1] * scale * m[1];
        result[1].x = (float)x_data_h[2] * scale * m[2];
        result[1].y = (float)x_data_h[3] * scale * m[3];

        result_h[0] = __float22half2_rn(result[0]);
        result_h[1] = __float22half2_rn(result[1]);

        out_cast[j] = result_f;
    }
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        for (int i = high_index; i < N; i++) {
            out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
        }
    }
}

template <typename T>
void launch_dropout_grad(T* vals_out, const T* vals, uint8_t* mask,
                         int total_count, float ratio, cudaStream_t stream)
{
    assert(unroll_factor == 4);

    const float scale = 1. / (1. - ratio);
    dropout_grad_kernel<<<DS_GET_BLOCKS(total_count / unroll_factor),
                          DS_CUDA_NUM_THREADS, 0, stream>>>(
        total_count, scale, vals, vals_out, mask);
}
template void launch_dropout_grad(float*, const float* vals, uint8_t* mask,
                                  int total_count, float ratio, cudaStream_t stream);
template void launch_dropout_grad(__half*, const __half* vals, uint8_t* mask,
                                  int total_count, float ratio, cudaStream_t stream);

__global__ void dropout_kernel(const int N, const int dim, const float ratio,
                               const float* bias, float* Xdata, uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int tid = threadIdx.x % (dim / unroll_factor);

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    float4* Xdata_cast = reinterpret_cast<float4*>(Xdata);
    uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
    const float4* bias_cast = reinterpret_cast<const float4*>(bias);

    CUDA_1D_KERNEL_LOOP(j, N) {
        float4 rand = curand_uniform4(&state);

        uint32_t m_32;
        uint8_t* m = (uint8_t*)&m_32;

        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        float4 x_data = Xdata_cast[j];
        float4 b_data = bias_cast[j % (dim / unroll_factor)];

        x_data.x += b_data.x;
        x_data.y += b_data.y;
        x_data.z += b_data.z;
        x_data.w += b_data.w;

        x_data.x = x_data.x * scale * m[0];
        x_data.y = x_data.y * scale * m[1];
        x_data.z = x_data.z * scale * m[2];
        x_data.w = x_data.w * scale * m[3];

        mask_32[j] = m_32;
        Xdata_cast[j] = x_data;
    }
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            float x_data = Xdata[i] + bias[i % dim];
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            Xdata[i] = x_data * scale * m;
            mask[i] = m;
        }
    }
}

__global__ void dropout_kernel(const int N, const int dim, const float ratio,
                               const __half* bias, __half* Xdata, uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int tid = threadIdx.x % (dim / unroll_factor);

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    float2* Xdata_cast = reinterpret_cast<float2*>(Xdata);
    uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
    const float2* bias_cast = reinterpret_cast<const float2*>(bias);

    CUDA_1D_KERNEL_LOOP(j, N) {
        float4 rand = curand_uniform4(&state);

        float2 data_f;
        __half2* data_h = reinterpret_cast<__half2*>(&data_f);

        float2 bias_f;
        __half2* bias_h = reinterpret_cast<__half2*>(&bias_f);

        data_f = Xdata_cast[j];
        bias_f = bias_cast[j % (dim / unroll_factor)];

        float2 data_h_0 = __half22float2(data_h[0]);
        float2 data_h_1 = __half22float2(data_h[1]);

        float2 bias_h_0 = __half22float2(bias_h[0]);
        float2 bias_h_1 = __half22float2(bias_h[1]);

        data_h_0.x += bias_h_0.x;
        data_h_0.y += bias_h_0.y;
        data_h_1.x += bias_h_1.x;
        data_h_1.y += bias_h_1.y;

        uint32_t m_32;
        uint8_t* m = (uint8_t*)&m_32;

        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
        data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
        data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
        data_h_1.y = __float2half(data_h_1.y * scale * m[3]);

        float2 result_f;
        __half2* result_h = reinterpret_cast<__half2*>(&result_f);

        result_h[0] = __float22half2_rn(data_h_0);
        result_h[1] = __float22half2_rn(data_h_1);

        Xdata_cast[j] = result_f;
        mask_32[j] = m_32;
    }
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            float x_data = (float)Xdata[i] + (float)bias[i % dim];
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            Xdata[i] = __float2half(x_data * scale * m);
            mask[i] = m;
        }
    }
}

template <typename T>
void launch_dropout(T* out, const T* bias, uint8_t* mask,
                    int batch, int dim, float ratio, cudaStream_t stream)
{
    assert(unroll_factor == 4);

    int total_count = batch * dim / unroll_factor;

    dim3 grid_dim = DS_GET_BLOCKS(total_count);
    dim3 block_dim = DS_CUDA_NUM_THREADS;

    uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);

    dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
        total_count, dim, ratio, bias, out, mask, seed);
}

template void launch_dropout(float*, const float* bias, uint8_t* mask,
                             int batch, int dim, float ratio, cudaStream_t stream);
template void launch_dropout(__half*, const __half* bias, uint8_t* mask,
                             int batch, int dim, float ratio, cudaStream_t stream);

__global__ void dropout_kernel(const int N, const int dim, const float ratio,
                               const float* input, const float* residual, const float* bias,
                               float* out, uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int tid = threadIdx.x % (dim / unroll_factor);

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    float4* out_cast = reinterpret_cast<float4*>(out);
    uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);

    const float4* bias_cast = reinterpret_cast<const float4*>(bias);
    const float4* residual_cast = reinterpret_cast<const float4*>(residual);
    const float4* input_cast = reinterpret_cast<const float4*>(input);

    CUDA_1D_KERNEL_LOOP(j, N) {
        float4 rand = curand_uniform4(&state);

        uint32_t m_32;
        uint8_t* m = (uint8_t*)&m_32;

        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        float4 out_data;
        float4 b_data = bias_cast[j % (dim / unroll_factor)];
        float4 res_data = residual_cast[j];
        float4 inp_data = input_cast[j];

        out_data.x = (b_data.x + inp_data.x);
        out_data.y = (b_data.y + inp_data.y);
        out_data.z = (b_data.z + inp_data.z);
        out_data.w = (b_data.w + inp_data.w);

        out_data.x = out_data.x * scale * m[0];
        out_data.y = out_data.y * scale * m[1];
        out_data.z = out_data.z * scale * m[2];
        out_data.w = out_data.w * scale * m[3];

        out_data.x += res_data.x;
        out_data.y += res_data.y;
        out_data.z += res_data.z;
        out_data.w += res_data.w;

        mask_32[j] = m_32;
        out_cast[j] = out_data;
    }
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            float x_data = input[i] + bias[i % dim];
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            x_data = x_data * scale * m;
            x_data += residual[i];

            out[i] = x_data;
            mask[i] = m;
        }
    }
}

__global__ void dropout_kernel(const int N, const int dim, const float ratio,
                               const __half* input, const __half* residual, const __half* bias,
                               __half* out, uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int tid = threadIdx.x % (dim / unroll_factor);

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    float2* out_cast = reinterpret_cast<float2*>(out);
    uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);

    const float2* bias_cast = reinterpret_cast<const float2*>(bias);
    const float2* residual_cast = reinterpret_cast<const float2*>(residual);
    const float2* input_cast = reinterpret_cast<const float2*>(input);

    CUDA_1D_KERNEL_LOOP(j, N) {
        float4 rand = curand_uniform4(&state);

        float2 data_f;
        __half2* data_h = reinterpret_cast<__half2*>(&data_f);

        float2 bias_f;
        __half2* bias_h = reinterpret_cast<__half2*>(&bias_f);

        float2 residual_f;
        __half2* residual_h = reinterpret_cast<__half2*>(&residual_f);

        float2 input_f;
        __half2* input_h = reinterpret_cast<__half2*>(&input_f);

        bias_f = bias_cast[j % (dim / unroll_factor)];
        residual_f = residual_cast[j];
        input_f = input_cast[j];

        float2 data_h_0 = __half22float2(data_h[0]);
        float2 data_h_1 = __half22float2(data_h[1]);

        float2 bias_h_0 = __half22float2(bias_h[0]);
        float2 bias_h_1 = __half22float2(bias_h[1]);

        float2 residual_h_0 = __half22float2(residual_h[0]);
        float2 residual_h_1 = __half22float2(residual_h[1]);

        float2 input_h_0 = __half22float2(input_h[0]);
        float2 input_h_1 = __half22float2(input_h[1]);

        data_h_0.x = (bias_h_0.x + input_h_0.x);
        data_h_0.y = (bias_h_0.y + input_h_0.y);
        data_h_1.x = (bias_h_1.x + input_h_1.x);
        data_h_1.y = (bias_h_1.y + input_h_1.y);

        uint32_t m_32;
        uint8_t* m = (uint8_t*)&m_32;

        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
        data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
        data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
        data_h_1.y = __float2half(data_h_1.y * scale * m[3]);

        data_h_0.x += residual_h_0.x;
        data_h_0.y += residual_h_0.y;
        data_h_1.x += residual_h_1.x;
        data_h_1.y += residual_h_1.y;

        float2 result_f;
        __half2* result_h = reinterpret_cast<__half2*>(&result_f);

        result_h[0] = __float22half2_rn(data_h_0);
        result_h[1] = __float22half2_rn(data_h_1);

        out_cast[j] = result_f;
        mask_32[j] = m_32;
    }
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            float x_data = (float)input[i] + (float)bias[i % dim];
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            x_data = x_data * scale * m;
            x_data += (float)residual[i];

            out[i] = __float2half(x_data);
            mask[i] = m;
        }
    }
}

template <typename T>
void launch_dropout(T* out, const T* input, const T* residual, const T* bias,
                    uint8_t* mask, int batch, int dim, float ratio, cudaStream_t stream)
{
    assert(unroll_factor == 4);

    int total_count = batch * dim / unroll_factor;
    dim3 grid_dim = DS_GET_BLOCKS(total_count);
    dim3 block_dim = DS_CUDA_NUM_THREADS;

    uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);

    dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
        total_count, dim, ratio, input, residual, bias, out, mask, seed);
}

template void launch_dropout(float*, const float*, const float* residual, const float* bias,
                             uint8_t* mask, int batch, int dim, float ratio,
                             cudaStream_t stream);
template void launch_dropout(__half*, const __half*, const __half* residual, const __half* bias,
                             uint8_t* mask, int batch, int dim, float ratio,
                             cudaStream_t stream);
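To make the entry point above concrete, the sketch below shows a host-side call into the plain (mask-producing) dropout path. It is illustrative only: the `example_dropout` helper and its buffer names are assumptions, and the device pointers are expected to hold at least `total_count` elements each.

#include <cstdint>
#include <cuda_runtime.h>

// Illustrative host-side launch of the plain dropout path above (forward mode).
// `d_out`, `d_in`, and `d_mask` are assumed device buffers; `dim` is the trailing
// (feature) dimension used to pick the block/grid split.
void example_dropout(float* d_out, const float* d_in, uint8_t* d_mask,
                     int total_count, int dim, float ratio, cudaStream_t stream)
{
    // bwd=false selects dropout_kernel; bwd=true would replay the stored mask instead.
    launch_dropout<float>(d_out, d_in, d_mask, total_count, dim, ratio, stream, false);
}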
deepspeed/ops/csrc/transformer/ds_transformer_cuda.cpp
0 → 100644
View file @
eadbbe09
#include <torch/extension.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "Timer.h"
#include "context.h"
#include "cublas_wrappers.h"
#include "custom_cuda_layers.h"
#include "ds_transformer_cuda.h"
static
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
void
>>
s_transformer_layers
;
const
int
init_seq_length
=
128
;
// C++ interface
template
<
typename
T
>
size_t
get_workspace_size
(
int
maxBatchSize
,
int
seq_len
,
int
hidden_size
,
int
intermediate_size
,
int
heads
,
bool
training
,
bool
gelu_checkpoint
)
{
size_t
workSpacesize
=
4
*
(
size_t
(
maxBatchSize
)
*
seq_len
*
hidden_size
);
if
(
training
)
{
workSpacesize
+=
((
std
::
max
)((
size_t
(
maxBatchSize
)
*
seq_len
*
intermediate_size
),
2
*
(
size_t
(
maxBatchSize
)
*
heads
*
seq_len
*
seq_len
)));
if
(
gelu_checkpoint
)
workSpacesize
+=
2
*
(
size_t
(
maxBatchSize
)
*
seq_len
*
intermediate_size
);
}
return
workSpacesize
;
// * sizeof(T);
}
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
template
<
typename
T
>
BertTransformerLayer
<
T
>::
BertTransformerLayer
(
int
layer_id
,
int
batch_size
,
int
hidden_size
,
int
num_heads
,
int
intermediate_size
,
int
seq_length
,
float
attn_prob_dropout_ratio
,
float
hidden_output_dropout_ratio
,
float
layer_norm_eps
,
bool
pre_or_postLayerNorm
,
const
std
::
vector
<
std
::
array
<
int
,
3
>>&
gemm_algos
,
bool
attn_dropout_checkpoint
,
bool
normalize_invertible
,
bool
gelu_checkpoint
,
bool
stochastic_mode
)
:
_layer_id
(
layer_id
),
_batch_size
(
batch_size
),
_hidden_size
(
hidden_size
),
_heads
(
num_heads
),
_intermediate_size
(
intermediate_size
),
_seq_length
(
seq_length
),
_training
(
true
),
_pre_or_postLayerNorm
(
pre_or_postLayerNorm
),
_attn_dropout_checkpoint
(
attn_dropout_checkpoint
),
_normalize_invertible
(
normalize_invertible
),
_gelu_checkpoint
(
gelu_checkpoint
),
_stochastic_mode
(
stochastic_mode
),
_stream
(
Context
::
Instance
().
GetCurrentStream
()),
_cublasHandle
(
Context
::
Instance
().
GetCublasHandle
()),
_qkv_linear
(
typename
FeedForward
<
T
>::
Config
(
batch_size
*
seq_length
,
3
*
hidden_size
,
hidden_size
,
gemm_algos
[
0
])),
_attn_out_linear
(
typename
FeedForward
<
T
>::
Config
(
batch_size
*
seq_length
,
hidden_size
,
hidden_size
,
gemm_algos
[
0
])),
_attn_layer_norm
(
typename
Normalize_Layer
<
T
>::
Config
(
batch_size
,
seq_length
,
hidden_size
,
layer_norm_eps
,
true
,
!
normalize_invertible
)),
_layer_norm
(
typename
Normalize_Layer
<
T
>::
Config
(
batch_size
,
seq_length
,
hidden_size
,
layer_norm_eps
,
true
,
!
normalize_invertible
)),
_ff1
(
typename
FeedForward
<
T
>::
Config
(
batch_size
*
seq_length
,
_intermediate_size
,
hidden_size
,
gemm_algos
[
1
])),
_ff2
(
typename
FeedForward
<
T
>::
Config
(
batch_size
*
seq_length
,
hidden_size
,
_intermediate_size
,
gemm_algos
[
2
])),
_softmax
(
typename
Softmax
<
T
>::
Config
(
batch_size
,
num_heads
,
seq_length
)),
_gelu
(
typename
Gelu
<
T
>::
Config
(
_intermediate_size
)),
_attn_prob_dropout
(
typename
Dropout
<
T
>::
Config
(
attn_prob_dropout_ratio
,
_seq_length
)),
_attn_output_dropout
(
typename
Dropout
<
T
>::
Config
(
hidden_output_dropout_ratio
,
_hidden_size
)),
_layer_output_dropout
(
typename
Dropout
<
T
>::
Config
(
hidden_output_dropout_ratio
,
_hidden_size
)),
_attn_scores
(
typename
StridedBatchGemm
<
T
>::
Config
(
_batch_size
*
_heads
,
_seq_length
,
_seq_length
,
_hidden_size
/
_heads
,
(
T
(
1.0
)
/
T
(
sqrt
(
_hidden_size
/
_heads
))),
T
(
0.0
),
CUBLAS_OP_T
,
CUBLAS_OP_N
,
gemm_algos
[
3
])),
_attn_context
(
typename
StridedBatchGemm
<
T
>::
Config
(
_batch_size
*
_heads
,
_hidden_size
/
_heads
,
_seq_length
,
_seq_length
,
T
(
1.0
),
T
(
0.0
),
CUBLAS_OP_N
,
CUBLAS_OP_N
,
gemm_algos
[
4
]))
{
assert
(
_hidden_size
%
_heads
==
0
);
Initialize
();
}
template
<
typename
T
>
BertTransformerLayer
<
T
>::~
BertTransformerLayer
()
{
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
Initialize
()
{
if
(
std
::
is_same
<
T
,
__half
>::
value
)
cublasSetMathMode
(
_cublasHandle
,
CUBLAS_TENSOR_OP_MATH
);
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
Forward
(
int
bsz
,
const
T
*
input_ptr
,
const
T
*
input_mask_ptr
,
const
T
*
attn_qkvw_ptr
,
const
T
*
attn_qkvb_ptr
,
const
T
*
attn_ow_ptr
,
const
T
*
attn_ob_ptr
,
const
T
*
attn_nw_ptr
,
const
T
*
attn_nb_ptr
,
const
T
*
inter_w_ptr
,
const
T
*
inter_b_ptr
,
const
T
*
output_w_ptr
,
const
T
*
output_b_ptr
,
const
T
*
norm_w_ptr
,
const
T
*
norm_b_ptr
,
T
*
out_ptr
,
T
*
inp_norm_ptr
,
T
*
q_tf_ptr
,
T
*
k_tf_ptr
,
T
*
v_tf_ptr
,
T
*
soft_out_ptr
,
T
*
ctx_bufB_ptr
,
T
*
attn_o_inp_ptr
,
T
*
add_res_ptr
,
T
*
ff1_inp_ptr
,
T
*
gelu_inp_ptr
,
T
*
ff2_inp_ptr
)
{
cublasSetStream
(
_cublasHandle
,
_stream
);
if
(
!
_stochastic_mode
)
cudaStreamSynchronize
(
_stream
);
T
*
workspace
=
static_cast
<
T
*>
(
Context
::
Instance
().
GetWorkSpace
());
size_t
small_buf_size
=
bsz
*
_seq_length
*
_hidden_size
;
T
*
buf_0
=
workspace
;
T
*
buf_1
=
buf_0
+
small_buf_size
;
T
*
buf_2
=
buf_1
;
if
(
_normalize_invertible
)
{
add_res_ptr
=
buf_1
+
3
*
small_buf_size
;
buf_2
=
add_res_ptr
;
}
if
(
_gelu_checkpoint
)
buf_2
+=
small_buf_size
;
if
(
_attn_dropout_checkpoint
)
ctx_bufB_ptr
=
(
_gelu_checkpoint
?
(
buf_2
+
(
_intermediate_size
/
_hidden_size
)
*
small_buf_size
)
:
(
buf_1
+
4
*
small_buf_size
));
int
bsz_seq
=
bsz
*
_seq_length
;
if
(
_pre_or_postLayerNorm
)
{
if
(
_layer_norm
.
UseMean
())
_layer_norm
.
ForwardCheckpoint
(
bsz_seq
,
inp_norm_ptr
,
input_ptr
,
norm_w_ptr
,
norm_b_ptr
,
_stream
,
true
);
else
_layer_norm
.
Forward
(
bsz_seq
,
inp_norm_ptr
,
input_ptr
,
norm_w_ptr
,
norm_b_ptr
,
_stream
,
true
);
}
if
(
_pre_or_postLayerNorm
)
_qkv_linear
.
Forward
(
bsz_seq
,
inp_norm_ptr
,
attn_qkvw_ptr
,
buf_0
,
_cublasHandle
);
else
_qkv_linear
.
Forward
(
bsz_seq
,
input_ptr
,
attn_qkvw_ptr
,
buf_0
,
_cublasHandle
);
launch_bias_add_transform_0213
<
T
>
(
q_tf_ptr
,
buf_0
,
attn_qkvb_ptr
,
bsz
,
_seq_length
,
_hidden_size
,
_heads
,
_stream
,
3
);
int
bsz_heads
=
bsz
*
_heads
;
// attention scores
_attn_scores
.
Forward
(
bsz_heads
,
soft_out_ptr
,
k_tf_ptr
,
q_tf_ptr
,
_cublasHandle
);
// Softmax + Mask
_softmax
.
Forward
(
bsz
,
soft_out_ptr
,
input_mask_ptr
,
_stream
);
// attn prob dropout.
_attn_prob_dropout
.
Forward
(
bsz_heads
*
_seq_length
,
ctx_bufB_ptr
,
soft_out_ptr
,
_stream
);
// attention context
_attn_context
.
Forward
(
bsz_heads
,
buf_1
,
v_tf_ptr
,
ctx_bufB_ptr
,
_cublasHandle
);
launch_transform4d_0213
<
T
>
(
attn_o_inp_ptr
,
buf_1
,
bsz
,
_heads
,
_seq_length
,
_hidden_size
,
_stream
,
1
);
if
(
_pre_or_postLayerNorm
)
_attn_out_linear
.
Forward
(
bsz_seq
,
attn_o_inp_ptr
,
attn_ow_ptr
,
buf_1
,
_cublasHandle
);
else
_attn_out_linear
.
Forward
(
bsz_seq
,
attn_o_inp_ptr
,
attn_ow_ptr
,
ff1_inp_ptr
,
_cublasHandle
);
// attn output dropout.
if
(
_pre_or_postLayerNorm
)
_attn_output_dropout
.
ForwardWithBias
(
bsz_seq
,
add_res_ptr
,
buf_1
,
input_ptr
,
attn_ob_ptr
,
_stream
);
else
_attn_output_dropout
.
ForwardWithBias
(
bsz_seq
,
add_res_ptr
,
ff1_inp_ptr
,
input_ptr
,
attn_ob_ptr
,
_stream
);
if
(
_pre_or_postLayerNorm
)
{
if
(
_attn_layer_norm
.
UseMean
())
_attn_layer_norm
.
ForwardCheckpoint
(
bsz_seq
,
ff1_inp_ptr
,
add_res_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
_stream
,
true
);
else
_attn_layer_norm
.
Forward
(
bsz_seq
,
ff1_inp_ptr
,
add_res_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
_stream
,
true
);
}
else
{
if
(
_attn_layer_norm
.
UseMean
())
_attn_layer_norm
.
ForwardCheckpoint
(
bsz_seq
,
ff1_inp_ptr
,
add_res_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
_stream
,
true
);
else
_attn_layer_norm
.
Forward
(
bsz_seq
,
ff1_inp_ptr
,
add_res_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
_stream
,
true
);
}
_ff1
.
Forward
(
bsz_seq
,
ff1_inp_ptr
,
inter_w_ptr
,
(
_gelu_checkpoint
?
ff2_inp_ptr
:
gelu_inp_ptr
),
_cublasHandle
);
_gelu
.
ForwardWithBiasAdd
(
bsz_seq
,
(
_gelu_checkpoint
?
ff2_inp_ptr
:
gelu_inp_ptr
),
inter_b_ptr
,
(
_gelu_checkpoint
?
buf_2
:
ff2_inp_ptr
),
_stream
);
_ff2
.
Forward
(
bsz_seq
,
(
_gelu_checkpoint
?
buf_2
:
ff2_inp_ptr
),
output_w_ptr
,
out_ptr
,
_cublasHandle
);
// layer output dropout.
if
(
_pre_or_postLayerNorm
)
_layer_output_dropout
.
ForwardWithBias
(
bsz_seq
,
out_ptr
,
out_ptr
,
add_res_ptr
,
output_b_ptr
,
_stream
);
else
_layer_output_dropout
.
ForwardWithBias
(
bsz_seq
,
inp_norm_ptr
,
out_ptr
,
ff1_inp_ptr
,
output_b_ptr
,
_stream
);
if
(
!
_pre_or_postLayerNorm
)
{
if
(
_layer_norm
.
UseMean
())
_layer_norm
.
ForwardCheckpoint
(
bsz_seq
,
out_ptr
,
inp_norm_ptr
,
norm_w_ptr
,
norm_b_ptr
,
_stream
,
true
);
else
_layer_norm
.
Forward
(
bsz_seq
,
out_ptr
,
inp_norm_ptr
,
norm_w_ptr
,
norm_b_ptr
,
_stream
,
true
);
}
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
Backward
(
int
bsz
,
const
T
*
grad_output_ptr
,
const
T
*
input_ptr
,
const
T
*
output_ptr
,
const
T
*
inp_norm_ptr
,
const
T
*
q_tf_ptr
,
const
T
*
k_tf_ptr
,
const
T
*
v_tf_ptr
,
const
T
*
soft_out_ptr
,
const
T
*
ctx_bufB_ptr
,
const
T
*
attn_o_inp_ptr
,
const
T
*
add_res_ptr
,
const
T
*
ff1_inp_ptr
,
const
T
*
gelu_inp_ptr
,
const
T
*
ff2_inp_ptr
,
const
T
*
input_mask_ptr
,
const
T
*
attn_qkvw_ptr
,
const
T
*
attn_ow_ptr
,
const
T
*
attn_nw_ptr
,
const
T
*
attn_nb_ptr
,
const
T
*
inter_w_ptr
,
const
T
*
inter_b_ptr
,
const
T
*
output_w_ptr
,
const
T
*
norm_w_ptr
,
const
T
*
norm_b_ptr
,
T
*
grad_input_ptr
,
T
*
grad_attn_qkvw_ptr
,
T
*
grad_attn_qkvb_ptr
,
T
*
grad_attn_ow_ptr
,
T
*
grad_attn_ob_ptr
,
T
*
grad_attn_nw_ptr
,
T
*
grad_attn_nb_ptr
,
T
*
grad_inter_w_ptr
,
T
*
grad_inter_b_ptr
,
T
*
grad_output_w_ptr
,
T
*
grad_output_b_ptr
,
T
*
grad_norm_w_ptr
,
T
*
grad_norm_b_ptr
)
{
cublasSetStream
(
_cublasHandle
,
_stream
);
if
(
!
_stochastic_mode
)
cudaStreamSynchronize
(
_stream
);
T
*
workspace
=
static_cast
<
T
*>
(
Context
::
Instance
().
GetWorkSpace
());
size_t
small_buf_size
=
bsz
*
_seq_length
*
_hidden_size
;
T
*
buf_0
=
workspace
;
T
*
buf_1
=
buf_0
+
small_buf_size
;
T
*
buf_2
=
buf_1
+
small_buf_size
;
T
*
buf_3
=
buf_2
+
small_buf_size
;
T
*
ff2_buf
=
(
_gelu_checkpoint
?
buf_3
+
(
bsz
*
_seq_length
*
_intermediate_size
)
:
buf_3
+
small_buf_size
);
T
*
ctx_bufB_ptr_recomp
=
ff2_buf
+
(
_seq_length
*
_seq_length
*
bsz
*
_heads
);
cudaStream_t
streams
[
2
]
=
{
_stream
,
_stream
};
int
bsz_seq
=
bsz
*
_seq_length
;
int
bsz_heads
=
bsz
*
_heads
;
if
(
!
_pre_or_postLayerNorm
)
{
if
(
_layer_norm
.
UseMean
())
_layer_norm
.
Backward
(
bsz_seq
,
grad_output_ptr
,
norm_w_ptr
,
grad_norm_w_ptr
,
grad_norm_b_ptr
,
streams
,
buf_1
,
inp_norm_ptr
);
else
_layer_norm
.
Backward
(
bsz_seq
,
grad_output_ptr
,
norm_w_ptr
,
norm_b_ptr
,
grad_norm_w_ptr
,
grad_norm_b_ptr
,
streams
,
buf_1
,
output_ptr
);
}
if
(
_pre_or_postLayerNorm
)
_layer_output_dropout
.
Backward
(
bsz_seq
,
buf_0
,
grad_output_ptr
,
_stream
);
else
_layer_output_dropout
.
Backward
(
bsz_seq
,
buf_0
,
buf_1
,
_stream
);
const
T
*
layer_dropout_buf
=
_layer_output_dropout
.
HasDropout
()
?
buf_0
:
(
_pre_or_postLayerNorm
?
grad_output_ptr
:
buf_1
);
if
(
_gelu_checkpoint
)
_gelu
.
ForwardWithBiasAdd
(
bsz_seq
,
ff2_inp_ptr
,
inter_b_ptr
,
buf_2
,
_stream
);
_ff2
.
Backward
(
bsz_seq
,
layer_dropout_buf
,
(
_gelu_checkpoint
?
buf_2
:
ff2_inp_ptr
),
output_w_ptr
,
grad_output_w_ptr
,
grad_output_b_ptr
,
_cublasHandle
,
_stream
,
ff2_buf
);
_gelu
.
Backward
(
bsz_seq
,
ff2_buf
,
(
_gelu_checkpoint
?
ff2_inp_ptr
:
gelu_inp_ptr
),
inter_b_ptr
,
_stream
);
_ff1
.
Backward
(
bsz_seq
,
ff2_buf
,
ff1_inp_ptr
,
inter_w_ptr
,
grad_inter_w_ptr
,
grad_inter_b_ptr
,
_cublasHandle
,
_stream
,
buf_3
);
if
(
!
_pre_or_postLayerNorm
)
launch_fused_add2
<
T
>
(
buf_2
,
buf_3
,
buf_1
,
bsz
,
_seq_length
,
_hidden_size
,
_stream
);
if
(
_pre_or_postLayerNorm
)
{
if
(
_attn_layer_norm
.
UseMean
())
_attn_layer_norm
.
BackwardFusedAdd
(
bsz_seq
,
buf_3
,
grad_output_ptr
,
attn_nw_ptr
,
grad_attn_nw_ptr
,
grad_attn_nb_ptr
,
streams
,
buf_0
,
add_res_ptr
);
else
_attn_layer_norm
.
BackwardFusedAdd
(
bsz_seq
,
buf_3
,
grad_output_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
grad_attn_nw_ptr
,
grad_attn_nb_ptr
,
streams
,
buf_0
,
ff1_inp_ptr
);
}
else
{
if
(
_attn_layer_norm
.
UseMean
())
_attn_layer_norm
.
Backward
(
bsz_seq
,
buf_2
,
attn_nw_ptr
,
grad_attn_nw_ptr
,
grad_attn_nb_ptr
,
streams
,
buf_0
,
add_res_ptr
);
else
_attn_layer_norm
.
Backward
(
bsz_seq
,
buf_2
,
attn_nw_ptr
,
attn_nb_ptr
,
grad_attn_nw_ptr
,
grad_attn_nb_ptr
,
streams
,
buf_0
,
ff1_inp_ptr
);
}
_attn_output_dropout
.
Backward
(
bsz_seq
,
buf_2
,
buf_0
,
_stream
);
T
*
attn_output_dropout_buf
=
_attn_output_dropout
.
HasDropout
()
?
buf_2
:
buf_0
;
_attn_out_linear
.
Backward
(
bsz_seq
,
attn_output_dropout_buf
,
attn_o_inp_ptr
,
attn_ow_ptr
,
grad_attn_ow_ptr
,
grad_attn_ob_ptr
,
_cublasHandle
,
_stream
,
buf_1
);
launch_transform_0213
<
T
>
(
buf_2
,
buf_1
,
bsz
,
_seq_length
,
_hidden_size
,
_heads
,
_stream
);
if
(
_attn_prob_dropout
.
HasDropout
())
{
if
(
_attn_dropout_checkpoint
)
_attn_prob_dropout
.
Forward
(
bsz_heads
*
_seq_length
,
ctx_bufB_ptr_recomp
,
soft_out_ptr
,
_stream
,
true
);
_attn_context
.
Backward
(
bsz_heads
,
buf_2
,
v_tf_ptr
,
(
_attn_dropout_checkpoint
?
ctx_bufB_ptr_recomp
:
ctx_bufB_ptr
),
_cublasHandle
,
buf_3
,
ff2_buf
);
}
else
_attn_context
.
Backward
(
bsz_heads
,
buf_2
,
v_tf_ptr
,
soft_out_ptr
,
_cublasHandle
,
buf_3
,
ff2_buf
);
_attn_prob_dropout
.
Backward
(
bsz_heads
*
_seq_length
,
ff2_buf
,
_stream
);
_softmax
.
Backward
(
bsz
,
ff2_buf
,
soft_out_ptr
,
_stream
);
_attn_scores
.
Backward
(
bsz_heads
,
ff2_buf
,
k_tf_ptr
,
q_tf_ptr
,
_cublasHandle
,
buf_2
,
buf_1
);
launch_transform4d_0213
(
ff2_buf
,
buf_1
,
bsz
,
_heads
,
_seq_length
,
_hidden_size
,
_stream
,
3
);
if
(
_pre_or_postLayerNorm
)
_qkv_linear
.
Backward
(
bsz_seq
,
ff2_buf
,
inp_norm_ptr
,
attn_qkvw_ptr
,
grad_attn_qkvw_ptr
,
grad_attn_qkvb_ptr
,
_cublasHandle
,
_stream
,
buf_2
);
else
_qkv_linear
.
Backward
(
bsz_seq
,
ff2_buf
,
input_ptr
,
attn_qkvw_ptr
,
grad_attn_qkvw_ptr
,
grad_attn_qkvb_ptr
,
_cublasHandle
,
_stream
,
buf_2
);
if
(
_pre_or_postLayerNorm
)
{
if
(
_layer_norm
.
UseMean
())
_layer_norm
.
BackwardFusedAdd
(
bsz_seq
,
buf_2
,
buf_0
,
norm_w_ptr
,
grad_norm_w_ptr
,
grad_norm_b_ptr
,
streams
,
grad_input_ptr
,
input_ptr
);
else
_layer_norm
.
BackwardFusedAdd
(
bsz_seq
,
buf_2
,
buf_0
,
norm_w_ptr
,
norm_b_ptr
,
grad_norm_w_ptr
,
grad_norm_b_ptr
,
streams
,
grad_input_ptr
,
inp_norm_ptr
);
}
else
launch_fused_add2
<
T
>
(
grad_input_ptr
,
buf_2
,
buf_0
,
bsz
,
_seq_length
,
_hidden_size
,
_stream
);
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
SetTrainingMode
(
bool
training
)
{
// Dropout will be skipped when not in training model.
_attn_prob_dropout
.
SetTrainingMode
(
training
);
_attn_output_dropout
.
SetTrainingMode
(
training
);
_layer_output_dropout
.
SetTrainingMode
(
training
);
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
SetIntermediateBuffers
(
uint8_t
*
attn_prob_dropout_mask_ptr
,
uint8_t
*
attn_output_dropout_mask_ptr
,
uint8_t
*
layer_output_dropout_mask_ptr
,
T
*
attn_layer_norm_var
,
T
*
attn_layer_norm_mean
,
T
*
layer_norm_var
,
T
*
layer_norm_mean
)
{
_attn_prob_dropout
.
SetMask
(
attn_prob_dropout_mask_ptr
);
_attn_output_dropout
.
SetMask
(
attn_output_dropout_mask_ptr
);
_layer_output_dropout
.
SetMask
(
layer_output_dropout_mask_ptr
);
_attn_layer_norm
.
SetVar
(
attn_layer_norm_var
);
_attn_layer_norm
.
SetMean
(
attn_layer_norm_mean
);
_layer_norm
.
SetVar
(
layer_norm_var
);
_layer_norm
.
SetMean
(
layer_norm_mean
);
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
SetSeqLength
(
int
seq_len
)
{
_seq_length
=
seq_len
;
_softmax
.
SetSeqLength
(
_seq_length
);
_attn_prob_dropout
.
SetDimension
(
_seq_length
);
_attn_scores
.
SetConfig
(
_seq_length
,
_seq_length
,
_hidden_size
/
_heads
);
_attn_context
.
SetConfig
(
_hidden_size
/
_heads
,
_seq_length
,
_seq_length
);
}
template
<
typename
T
>
int
create_transformer_layer
(
int
layer_id
,
int
batch_size
,
int
hidden_dim
,
int
num_heads
,
int
intermediate_size
,
float
attn_dropout_ratio
,
float
hidden_dropout_ratio
,
float
layer_norm_eps
,
int
seed
,
bool
pre_or_postLayerNorm
,
bool
test_gemm
,
bool
attn_dropout_checkpoint
,
bool
normalize_invertible
,
bool
gelu_checkpoint
,
bool
stochastic_mode
)
{
Context
::
Instance
().
SetSeed
(
seed
);
Context
::
Instance
().
TestGemmFP16
(
test_gemm
,
batch_size
,
init_seq_length
,
num_heads
,
hidden_dim
/
num_heads
);
auto
layer
=
std
::
make_shared
<
BertTransformerLayer
<
T
>>
(
layer_id
,
batch_size
,
hidden_dim
,
num_heads
,
intermediate_size
,
init_seq_length
,
attn_dropout_ratio
,
hidden_dropout_ratio
,
layer_norm_eps
,
pre_or_postLayerNorm
,
Context
::
Instance
().
GetGemmAlgos
(),
attn_dropout_checkpoint
,
normalize_invertible
,
gelu_checkpoint
,
stochastic_mode
);
s_transformer_layers
[
layer_id
]
=
layer
;
std
::
string
dtype
=
(
std
::
is_same
<
T
,
__half
>::
value
)
?
"half"
:
"float"
;
std
::
cout
<<
"layer #"
<<
layer_id
<<
" is created with date type ["
<<
dtype
<<
"]."
<<
std
::
endl
;
return
0
;
}
template
<
typename
T
>
std
::
vector
<torch::Tensor> ds_transformer_forward(
    int layer_id,
    const torch::Tensor& input,
    const torch::Tensor& input_mask,
    const torch::Tensor& attn_qkvw,
    const torch::Tensor& attn_qkvb,
    const torch::Tensor& attn_ow,
    const torch::Tensor& attn_ob,
    const torch::Tensor& attn_nw,
    const torch::Tensor& attn_nb,
    const torch::Tensor& inter_w,
    const torch::Tensor& inter_b,
    const torch::Tensor& output_w,
    const torch::Tensor& output_b,
    const torch::Tensor& norm_w,
    const torch::Tensor& norm_b,
    bool training_mode,
    bool prelayernorm,
    bool attn_dropout_checkpoint,
    bool normalize_invertible,
    bool gelu_checkpoint)
{
    CHECK_INPUT(input);
    CHECK_INPUT(input_mask);
    CHECK_INPUT(attn_qkvw);
    CHECK_INPUT(attn_qkvb);
    CHECK_INPUT(attn_ow);
    CHECK_INPUT(attn_ob);
    CHECK_INPUT(attn_nw);
    CHECK_INPUT(attn_nb);
    CHECK_INPUT(inter_w);
    CHECK_INPUT(inter_b);
    CHECK_INPUT(output_w);
    CHECK_INPUT(output_b);
    CHECK_INPUT(norm_w);
    CHECK_INPUT(norm_b);

    int bsz = input.size(0);

    const T* input_ptr = (const T*)input.data_ptr();
    const T* input_mask_ptr = (const T*)input_mask.data_ptr();
    const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
    const T* attn_qkvb_ptr = (const T*)attn_qkvb.data_ptr();
    const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
    const T* attn_ob_ptr = (const T*)attn_ob.data_ptr();
    const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
    const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
    const T* inter_w_ptr = (const T*)inter_w.data_ptr();
    const T* inter_b_ptr = (const T*)inter_b.data_ptr();
    const T* output_w_ptr = (const T*)output_w.data_ptr();
    const T* output_b_ptr = (const T*)output_b.data_ptr();
    const T* norm_w_ptr = (const T*)norm_w.data_ptr();
    const T* norm_b_ptr = (const T*)norm_b.data_ptr();

    auto output = torch::empty_like(input);
    T* out_ptr = (T*)output.data_ptr();

    auto options = torch::TensorOptions()
                       .dtype(input.options().dtype())
                       .layout(torch::kStrided)
                       .device(torch::kCUDA)
                       .requires_grad(true);

    auto uint8_options = torch::TensorOptions()
                             .dtype(torch::kInt8)
                             .layout(torch::kStrided)
                             .device(torch::kCUDA)
                             .requires_grad(false);

    std::shared_ptr<BertTransformerLayer<T>> layer =
        std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

    int seq_len = layer->GetSeqLength();
    if (input.size(1) != seq_len) {
        seq_len = input.size(1);
        layer->SetSeqLength(seq_len);
    }

    auto workspace = torch::empty({get_workspace_size<T>(bsz,
                                                         seq_len,
                                                         layer->GetHiddenSize(),
                                                         layer->GetIntermediateSize(),
                                                         layer->GetNumHeads(),
                                                         layer->IsTrainingMode(),
                                                         layer->GeluCheckpoint())},
                                  options);
    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());

    auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
    auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
    auto attn_o_inp = torch::empty_like(input);
    auto qkv_tf = torch::empty({(bsz * seq_len), output_w.size(0) * 3}, options);

    auto attn_prob_dropout_mask =
        torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, uint8_options);
    auto attn_output_dropout_mask =
        torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);
    auto layer_output_dropout_mask =
        torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);

    auto attn_layer_norm_var = torch::empty({(bsz * seq_len)}, options);
    auto attn_layer_norm_mean = torch::empty({(bsz * seq_len)}, options);
    auto layer_norm_var = torch::empty({(bsz * seq_len)}, options);
    auto layer_norm_mean = torch::empty({(bsz * seq_len)}, options);

    T* inp_norm_ptr = (T*)inp_norm.data_ptr();
    T* add_res_ptr = (T*)add_res.data_ptr();
    T* q_tf_ptr = (T*)qkv_tf.data_ptr();
    T* k_tf_ptr = q_tf_ptr + (bsz * seq_len * output_w.size(0));  //(T*)k_tf.data_ptr();
    T* v_tf_ptr = k_tf_ptr + (bsz * seq_len * output_w.size(0));  //(T*)v_tf.data_ptr();
    T* attn_o_inp_ptr = (T*)attn_o_inp.data_ptr();

    torch::Tensor ff2_inp = torch::empty({(bsz * seq_len), output_w.size(1)}, options);
    torch::Tensor gelu_inp =
        (gelu_checkpoint ? ff2_inp : torch::empty({(bsz * seq_len), output_w.size(1)}, options));
    auto ff1_inp = torch::empty_like(input);
    T* ff2_inp_ptr = (T*)ff2_inp.data_ptr();
    T* gelu_inp_ptr = (T*)gelu_inp.data_ptr();
    T* ff1_inp_ptr = (T*)ff1_inp.data_ptr();

    torch::Tensor soft_out =
        torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options);
    torch::Tensor ctx_bufB =
        (attn_dropout_checkpoint
             ? soft_out
             : torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options));
    T* soft_out_ptr = (T*)soft_out.data_ptr();
    T* ctx_bufB_ptr = (T*)ctx_bufB.data_ptr();

    layer->SetTrainingMode(training_mode);
    layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
                                  (uint8_t*)attn_output_dropout_mask.data_ptr(),
                                  (uint8_t*)layer_output_dropout_mask.data_ptr(),
                                  (T*)attn_layer_norm_var.data_ptr(),
                                  (T*)attn_layer_norm_mean.data_ptr(),
                                  (T*)layer_norm_var.data_ptr(),
                                  (T*)layer_norm_mean.data_ptr());

    layer->Forward(bsz,
                   input_ptr,
                   input_mask_ptr,
                   attn_qkvw_ptr,
                   attn_qkvb_ptr,
                   attn_ow_ptr,
                   attn_ob_ptr,
                   attn_nw_ptr,
                   attn_nb_ptr,
                   inter_w_ptr,
                   inter_b_ptr,
                   output_w_ptr,
                   output_b_ptr,
                   norm_w_ptr,
                   norm_b_ptr,
                   out_ptr,
                   inp_norm_ptr,
                   q_tf_ptr,
                   k_tf_ptr,
                   v_tf_ptr,
                   soft_out_ptr,
                   ctx_bufB_ptr,
                   attn_o_inp_ptr,
                   add_res_ptr,
                   ff1_inp_ptr,
                   gelu_inp_ptr,
                   ff2_inp_ptr);

    return {output,
            inp_norm,
            qkv_tf,
            soft_out,
            ctx_bufB,
            attn_o_inp,
            add_res,
            ff1_inp,
            gelu_inp,
            ff2_inp,
            attn_prob_dropout_mask,
            attn_output_dropout_mask,
            layer_output_dropout_mask,
            attn_layer_norm_var,
            attn_layer_norm_mean,
            layer_norm_var,
            layer_norm_mean};
}
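// Note on the buffer reuse in ds_transformer_forward above (comment added for
// clarity; it only summarizes what the code already does). When
// normalize_invertible is set, add_res aliases inp_norm, and unless
// prelayernorm is also set, inp_norm aliases output, so the residual and
// normalized activations are not materialized separately. Likewise,
// gelu_checkpoint makes gelu_inp alias ff2_inp, and attn_dropout_checkpoint
// makes ctx_bufB alias soft_out. For the aliased buffers the layer is expected
// to recompute (or reconstruct) those activations inside its Backward pass
// instead of reading saved copies, trading extra compute for activation memory.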
template <typename T>
std::vector<torch::Tensor> ds_transformer_backward(
    int layer_id,
    const torch::Tensor& grad_output,
    const torch::Tensor& output,
    const torch::Tensor& inp_norm,
    const torch::Tensor& qkv_tf,
    const torch::Tensor& soft_out,
    const torch::Tensor& ctx_bufB,
    const torch::Tensor& attn_o_inp,
    const torch::Tensor& add_res,
    const torch::Tensor& ff1_inp,
    const torch::Tensor& gelu_inp,
    const torch::Tensor& ff2_inp,
    const torch::Tensor& attn_prob_dropout_mask,
    const torch::Tensor& attn_output_dropout_mask,
    const torch::Tensor& layer_output_dropout_mask,
    const torch::Tensor& attn_layer_norm_var,
    const torch::Tensor& attn_layer_norm_mean,
    const torch::Tensor& layer_norm_var,
    const torch::Tensor& layer_norm_mean,
    const torch::Tensor& input,
    const torch::Tensor& input_mask,
    const torch::Tensor& attn_qkvw,
    const torch::Tensor& attn_qkvb,
    const torch::Tensor& attn_ow,
    const torch::Tensor& attn_ob,
    const torch::Tensor& attn_nw,
    const torch::Tensor& attn_nb,
    const torch::Tensor& inter_w,
    const torch::Tensor& inter_b,
    const torch::Tensor& output_w,
    const torch::Tensor& output_b,
    const torch::Tensor& norm_w,
    const torch::Tensor& norm_b)
{
    auto g_output = grad_output.contiguous();
    CHECK_INPUT(g_output);
    CHECK_INPUT(output);
    CHECK_INPUT(inp_norm);
    CHECK_INPUT(qkv_tf);
    CHECK_INPUT(add_res);
    CHECK_INPUT(soft_out);
    CHECK_INPUT(ctx_bufB);
    CHECK_INPUT(attn_o_inp);
    CHECK_INPUT(ff1_inp);
    CHECK_INPUT(gelu_inp);
    CHECK_INPUT(ff2_inp);
    CHECK_INPUT(input);
    CHECK_INPUT(input_mask);
    CHECK_INPUT(attn_qkvw);
    CHECK_INPUT(attn_qkvb);
    CHECK_INPUT(attn_ow);
    CHECK_INPUT(attn_ob);
    CHECK_INPUT(attn_nw);
    CHECK_INPUT(attn_nb);
    CHECK_INPUT(inter_w);
    CHECK_INPUT(inter_b);
    CHECK_INPUT(output_w);
    CHECK_INPUT(output_b);
    CHECK_INPUT(norm_w);
    CHECK_INPUT(norm_b);

    int bsz = g_output.size(0);

    std::shared_ptr<BertTransformerLayer<T>> layer =
        std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

    int seq_len = layer->GetSeqLength();
    if (g_output.size(1) != seq_len) {
        seq_len = g_output.size(1);
        layer->SetSeqLength(seq_len);
    }

    auto options = torch::TensorOptions()
                       .dtype(g_output.options().dtype())
                       .layout(torch::kStrided)
                       .device(torch::kCUDA)
                       .requires_grad(true);

    auto workspace = torch::empty({get_workspace_size<T>(bsz,
                                                         seq_len,
                                                         layer->GetHiddenSize(),
                                                         layer->GetIntermediateSize(),
                                                         layer->GetNumHeads(),
                                                         layer->IsTrainingMode(),
                                                         layer->GeluCheckpoint())},
                                  options);
    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());

    auto grad_input = torch::empty_like(input);
    auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
    auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
    auto grad_attn_ow = torch::empty_like(attn_ow);
    auto grad_attn_ob = torch::empty_like(attn_ob);
    auto grad_attn_nw = torch::empty_like(attn_nw);
    auto grad_attn_nb = torch::empty_like(attn_nb);
    auto grad_inter_w = torch::empty_like(inter_w);
    auto grad_inter_b = torch::empty_like(inter_b);
    auto grad_output_w = torch::empty_like(output_w);
    auto grad_output_b = torch::empty_like(output_b);
    auto grad_norm_w = torch::empty_like(norm_w);
    auto grad_norm_b = torch::empty_like(norm_b);

    // inputs.
    const T* grad_output_ptr = (const T*)g_output.data_ptr();
    const T* input_ptr = (const T*)input.data_ptr();
    const T* output_ptr = (const T*)output.data_ptr();
    const T* inp_norm_ptr = (const T*)inp_norm.data_ptr();
    const T* q_tf_ptr = (const T*)qkv_tf.data_ptr();
    const T* add_res_ptr = (const T*)add_res.data_ptr();
    const T* k_tf_ptr =
        q_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));  //(const T*)k_tf.data_ptr();
    const T* v_tf_ptr =
        k_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));  //(const T*)v_tf.data_ptr();
    const T* ff1_inp_ptr = (const T*)ff1_inp.data_ptr();
    const T* gelu_inp_ptr = (const T*)gelu_inp.data_ptr();
    const T* ff2_inp_ptr = (const T*)ff2_inp.data_ptr();
    const T* ctx_bufB_ptr = (const T*)ctx_bufB.data_ptr();
    const T* soft_out_ptr = (const T*)soft_out.data_ptr();
    const T* attn_o_inp_ptr = (const T*)attn_o_inp.data_ptr();
    const T* input_mask_ptr = (const T*)input_mask.data_ptr();
    const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
    const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
    const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
    const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
    const T* inter_w_ptr = (const T*)inter_w.data_ptr();
    const T* inter_b_ptr = (const T*)inter_b.data_ptr();
    const T* output_w_ptr = (const T*)output_w.data_ptr();
    const T* norm_w_ptr = (const T*)norm_w.data_ptr();
    const T* norm_b_ptr = (const T*)norm_b.data_ptr();

    // outputs.
    T* grad_input_ptr = (T*)grad_input.data_ptr();
    T* grad_attn_qkvw_ptr = (T*)grad_attn_qkvw.data_ptr();
    T* grad_attn_qkvb_ptr = (T*)grad_attn_qkvb.data_ptr();
    T* grad_attn_ow_ptr = (T*)grad_attn_ow.data_ptr();
    T* grad_attn_ob_ptr = (T*)grad_attn_ob.data_ptr();
    T* grad_attn_nw_ptr = (T*)grad_attn_nw.data_ptr();
    T* grad_attn_nb_ptr = (T*)grad_attn_nb.data_ptr();
    T* grad_inter_w_ptr = (T*)grad_inter_w.data_ptr();
    T* grad_inter_b_ptr = (T*)grad_inter_b.data_ptr();
    T* grad_output_w_ptr = (T*)grad_output_w.data_ptr();
    T* grad_output_b_ptr = (T*)grad_output_b.data_ptr();
    T* grad_norm_w_ptr = (T*)grad_norm_w.data_ptr();
    T* grad_norm_b_ptr = (T*)grad_norm_b.data_ptr();

    layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
                                  (uint8_t*)attn_output_dropout_mask.data_ptr(),
                                  (uint8_t*)layer_output_dropout_mask.data_ptr(),
                                  (T*)attn_layer_norm_var.data_ptr(),
                                  (T*)attn_layer_norm_mean.data_ptr(),
                                  (T*)layer_norm_var.data_ptr(),
                                  (T*)layer_norm_mean.data_ptr());

    layer->Backward(bsz,
                    grad_output_ptr,
                    input_ptr,
                    output_ptr,
                    inp_norm_ptr,
                    q_tf_ptr,
                    k_tf_ptr,
                    v_tf_ptr,
                    soft_out_ptr,
                    ctx_bufB_ptr,
                    attn_o_inp_ptr,
                    add_res_ptr,
                    ff1_inp_ptr,
                    gelu_inp_ptr,
                    ff2_inp_ptr,
                    input_mask_ptr,
                    attn_qkvw_ptr,
                    attn_ow_ptr,
                    attn_nw_ptr,
                    attn_nb_ptr,
                    inter_w_ptr,
                    inter_b_ptr,
                    output_w_ptr,
                    norm_w_ptr,
                    norm_b_ptr,
                    grad_input_ptr,
                    grad_attn_qkvw_ptr,
                    grad_attn_qkvb_ptr,
                    grad_attn_ow_ptr,
                    grad_attn_ob_ptr,
                    grad_attn_nw_ptr,
                    grad_attn_nb_ptr,
                    grad_inter_w_ptr,
                    grad_inter_b_ptr,
                    grad_output_w_ptr,
                    grad_output_b_ptr,
                    grad_norm_w_ptr,
                    grad_norm_b_ptr);

    return {grad_input,
            grad_attn_qkvw,
            grad_attn_qkvb,
            grad_attn_ow,
            grad_attn_ob,
            grad_attn_nw,
            grad_attn_nb,
            grad_inter_w,
            grad_inter_b,
            grad_output_w,
            grad_output_b,
            grad_norm_w,
            grad_norm_b};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward_fp32",
          &ds_transformer_forward<float>,
          "DeepSpeed Transformer forward with fp32 (CUDA)");
    m.def("forward_fp16",
          &ds_transformer_forward<__half>,
          "DeepSpeed Transformer forward with fp16 (CUDA)");
    m.def("backward_fp32",
          &ds_transformer_backward<float>,
          "DeepSpeed Transformer backward with fp32 (CUDA)");
    m.def("backward_fp16",
          &ds_transformer_backward<__half>,
          "DeepSpeed Transformer backward with fp16 (CUDA)");
    m.def("create_transformer_layer_fp32",
          &create_transformer_layer<float>,
          "Create DeepSpeed Transformer Layer with fp32 (CUDA)");
    m.def("create_transformer_layer_fp16",
          &create_transformer_layer<__half>,
          "Create DeepSpeed Transformer Layer with fp16 (CUDA)");
}
deepspeed/ops/csrc/transformer/gelu_kernels.cu
0 → 100644
#include "custom_cuda_layers.h"
inline
__device__
float
gelu
(
const
float
x
)
{
const
float
sqrt_param
=
0.79788456080286535587989211986876
f
;
const
float
mul_param
=
0.044715
;
return
x
*
0.5
f
*
(
1.0
f
+
tanhf
(
sqrt_param
*
(
x
+
mul_param
*
x
*
x
*
x
)));
}
inline
__device__
float
d_gelu
(
const
float
x
)
{
const
float
sqrt_param
=
0.79788456080286535587989211986876
f
;
const
float
mul_param
=
0.044715
;
float
x2mul
=
x
*
x
*
mul_param
;
float
tan_h
=
tanhf
(
sqrt_param
*
(
x
+
x
*
x2mul
));
float
dg1
=
0.5
f
*
(
1.0
f
+
tan_h
);
float
dg2
=
x
*
0.5
f
*
sqrt_param
*
(
1
-
tan_h
*
tan_h
);
float
dg3
=
dg2
*
3
*
x2mul
;
return
(
dg1
+
dg2
+
dg3
);
}
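/*
  Quick reference for the math implemented by gelu() and d_gelu() above
  (comment added for clarity; it restates the code, nothing more):

      GELU(x) \approx 0.5\,x\,(1 + \tanh(u)),
      u = \sqrt{2/\pi}\,(x + 0.044715\,x^3)

      \frac{d}{dx}\mathrm{GELU}(x) \approx 0.5\,(1 + \tanh(u))
          + 0.5\,x\,\sqrt{2/\pi}\,(1 - \tanh^2(u))\,(1 + 3 \cdot 0.044715\,x^2)

  sqrt(2/pi) ~= 0.7978845608 is the sqrt_param constant; dg1, dg2 and dg3 in
  d_gelu are the three terms of the expanded derivative, with
  dg3 = dg2 * 3 * 0.044715 * x^2.
*/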
/*
  Fused bias add with GELU.

  Each thread loads a vector of 4 elements per iteration and strides across the
  row for `iterations` passes. The kernel was written with 256-thread blocks in
  mind, so for BERT-large ITERATIONS would be 4. This is now derived
  automatically as a heuristic: the iteration count comes from splitting the
  intermediate size into chunks of 1024 elements.

  For FP16, values are loaded from memory as __half but converted to FP32 for
  the arithmetic itself, to avoid overflow in the intermediate hyperbolic
  tangent, since there is no intrinsic that computes it directly on __half.
*/
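// Worked example of the launch heuristic used by launch_bias_gelu / launch_gelu
// below (BERT-large values, intermediate_size = 4096; shown purely for
// illustration and not required by the kernels):
//
//   iterations = (4096 + 1023) / 1024     = 4
//   threads    = 4096 / iterations / 4    = 256   // 4 elements per thread (float4)
//
static_assert((4096 + 1023) / 1024 == 4, "BERT-large: 4 iterations per row");
static_assert(4096 / 4 / 4 == 256, "BERT-large: 256 threads per block");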
__global__ void gelu_kernel(const float* input, float* vals, int intermediate_size)
{
    int row = blockIdx.x;
    int id = threadIdx.x;
    int loop_stride = blockDim.x;
    int iterations = intermediate_size / blockDim.x / 4;
    int row_stride = intermediate_size / 4;

    const float4* input_cast = reinterpret_cast<const float4*>(input);
    float4* vals_cast = reinterpret_cast<float4*>(vals);

    for (int i = 0; i < iterations; i++) {
        if (i * loop_stride + id < row_stride) {
            float4 data = input_cast[row * row_stride + i * loop_stride + id];

            data.x = gelu(data.x);
            data.y = gelu(data.y);
            data.z = gelu(data.z);
            data.w = gelu(data.w);

            vals_cast[row * row_stride + i * loop_stride + id] = data;
        }
    }
}

__global__ void gelu_kernel(const __half* input, __half* vals, int intermediate_size)
{
#if __CUDA_ARCH__ >= 700
    int row = blockIdx.x;
    int id = threadIdx.x;
    int loop_stride = blockDim.x;
    int iterations = intermediate_size / blockDim.x / 4;
    int row_stride = intermediate_size / 4;

    const float2* input_cast = reinterpret_cast<const float2*>(input);
    float2* vals_cast = reinterpret_cast<float2*>(vals);

    for (int i = 0; i < iterations; i++) {
        if (i * loop_stride + id < row_stride) {
            float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];

            __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);

            float2 low_data = __half22float2(vals_half[0]);
            float2 high_data = __half22float2(vals_half[1]);

            low_data.x = gelu(low_data.x);
            low_data.y = gelu(low_data.y);
            high_data.x = gelu(high_data.x);
            high_data.y = gelu(high_data.y);

            vals_half[0] = __float22half2_rn(low_data);
            vals_half[1] = __float22half2_rn(high_data);

            vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
        }
    }
#endif
}

__global__ void fused_bias_gelu(const float* input,
                                const float* bias,
                                float* vals,
                                int intermediate_size)
{
    int row = blockIdx.x;
    int id = threadIdx.x;
    int loop_stride = blockDim.x;
    int iterations = intermediate_size / blockDim.x / 4;
    int row_stride = intermediate_size / 4;

    const float4* input_cast = reinterpret_cast<const float4*>(input);
    float4* vals_cast = reinterpret_cast<float4*>(vals);
    const float4* bias_cast = reinterpret_cast<const float4*>(bias);

    for (int i = 0; i < iterations; i++) {
        if (i * loop_stride + id < row_stride) {
            float4 data = input_cast[row * row_stride + i * loop_stride + id];
            float4 bias_data = bias_cast[i * loop_stride + id];

            data.x += bias_data.x;
            data.y += bias_data.y;
            data.z += bias_data.z;
            data.w += bias_data.w;

            data.x = gelu(data.x);
            data.y = gelu(data.y);
            data.z = gelu(data.z);
            data.w = gelu(data.w);

            vals_cast[row * row_stride + i * loop_stride + id] = data;
        }
    }
}

__global__ void fused_bias_gelu(const __half* input,
                                const __half* bias,
                                __half* vals,
                                int intermediate_size)
{
#if __CUDA_ARCH__ >= 700
    int row = blockIdx.x;
    int id = threadIdx.x;
    int loop_stride = blockDim.x;
    int iterations = intermediate_size / blockDim.x / 4;
    int row_stride = intermediate_size / 4;

    const float2* input_cast = reinterpret_cast<const float2*>(input);
    float2* vals_cast = reinterpret_cast<float2*>(vals);
    const float2* bias_cast = reinterpret_cast<const float2*>(bias);

    for (int i = 0; i < iterations; i++) {
        if (i * loop_stride + id < row_stride) {
            float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];
            float2 bias_vec = bias_cast[i * loop_stride + id];

            __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
            __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);

            float2 low_data = __half22float2(vals_half[0]);
            float2 high_data = __half22float2(vals_half[1]);

            float2 low_bias = __half22float2(bias_half[0]);
            float2 high_bias = __half22float2(bias_half[1]);

            low_data.x += low_bias.x;
            low_data.y += low_bias.y;
            high_data.x += high_bias.x;
            high_data.y += high_bias.y;

            low_data.x = gelu(low_data.x);
            low_data.y = gelu(low_data.y);
            high_data.x = gelu(high_data.x);
            high_data.y = gelu(high_data.y);

            vals_half[0] = __float22half2_rn(low_data);
            vals_half[1] = __float22half2_rn(high_data);

            vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
        }
    }
#endif
}

__global__ void d_gelu_func(float* d_output,
                            const float* gelu_input,
                            const float* bias,
                            int intermediate_size)
{
    int row = blockIdx.x;
    int id = threadIdx.x;
    int loop_stride = blockDim.x;
    int iterations = intermediate_size / blockDim.x / 4;
    int row_stride = intermediate_size / 4;

    float4* d_output_cast = reinterpret_cast<float4*>(d_output);
    const float4* gelu_input_cast = reinterpret_cast<const float4*>(gelu_input);
    const float4* bias_cast = reinterpret_cast<const float4*>(bias);

    for (int i = 0; i < iterations; i++) {
        if (i * loop_stride + id < row_stride) {
            float4 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
            float4 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
            float4 bias_data = bias_cast[i * loop_stride + id];

            gelu_input_data.x += bias_data.x;
            gelu_input_data.y += bias_data.y;
            gelu_input_data.z += bias_data.z;
            gelu_input_data.w += bias_data.w;

            output_data.x *= d_gelu(gelu_input_data.x);
            output_data.y *= d_gelu(gelu_input_data.y);
            output_data.z *= d_gelu(gelu_input_data.z);
            output_data.w *= d_gelu(gelu_input_data.w);

            d_output_cast[row * row_stride + i * loop_stride + id] = output_data;
        }
    }
}

__global__ void d_gelu_func(__half* d_output,
                            const __half* gelu_input,
                            const __half* bias,
                            int intermediate_size)
{
#if __CUDA_ARCH__ >= 700
    int row = blockIdx.x;
    int id = threadIdx.x;
    int loop_stride = blockDim.x;
    int iterations = intermediate_size / blockDim.x / 4;
    int row_stride = intermediate_size / 4;

    float2* d_output_cast = reinterpret_cast<float2*>(d_output);
    const float2* gelu_input_cast = reinterpret_cast<const float2*>(gelu_input);
    const float2* bias_cast = reinterpret_cast<const float2*>(bias);

#pragma unroll
    for (int i = 0; i < iterations; i++) {
        if (i * loop_stride + id < row_stride) {
            float2 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
            float2 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
            float2 bias_vec = bias_cast[i * loop_stride + id];

            __half2* output_data_half = reinterpret_cast<__half2*>(&output_data);
            __half2* gelu_input_data_half = reinterpret_cast<__half2*>(&gelu_input_data);
            __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);

            float2 output_half_0 = __half22float2(output_data_half[0]);
            float2 output_half_1 = __half22float2(output_data_half[1]);

            float2 gelu_input_half_0 = __half22float2(gelu_input_data_half[0]);
            float2 gelu_input_half_1 = __half22float2(gelu_input_data_half[1]);

            float2 bias_half_0 = __half22float2(bias_half[0]);
            float2 bias_half_1 = __half22float2(bias_half[1]);

            gelu_input_half_0.x += bias_half_0.x;
            gelu_input_half_0.y += bias_half_0.y;
            gelu_input_half_1.x += bias_half_1.x;
            gelu_input_half_1.y += bias_half_1.y;

            output_half_0.x *= d_gelu(gelu_input_half_0.x);
            output_half_0.y *= d_gelu(gelu_input_half_0.y);
            output_half_1.x *= d_gelu(gelu_input_half_1.x);
            output_half_1.y *= d_gelu(gelu_input_half_1.y);

            float2 result;
            __half2* result_half2 = reinterpret_cast<__half2*>(&result);

            result_half2[0] = __float22half2_rn(output_half_0);
            result_half2[1] = __float22half2_rn(output_half_1);

            d_output_cast[row * row_stride + i * loop_stride + id] = result;
        }
    }
#endif
}

template <typename T>
void launch_bias_gelu(const T* input,
                      const T* bias,
                      T* output,
                      int intermediate_size,
                      int batch_size,
                      cudaStream_t stream)
{
    int iterations = (intermediate_size + 1023) / 1024;
    int threads = intermediate_size / iterations / 4;
    dim3 block_dims(threads);
    dim3 grid_dims(batch_size);

    fused_bias_gelu<<<grid_dims, block_dims, 0, stream>>>(input, bias, output, intermediate_size);
}

template <typename T>
void launch_gelu(const T* input,
                 T* output,
                 int intermediate_size,
                 int batch_size,
                 cudaStream_t stream)
{
    int iterations = (intermediate_size + 1023) / 1024;
    int threads = intermediate_size / iterations / 4;
    dim3 block_dims(threads);
    dim3 grid_dims(batch_size);

    gelu_kernel<<<grid_dims, block_dims, 0, stream>>>(input, output, intermediate_size);
}

template void launch_bias_gelu<float>(const float*, const float*, float*, int, int, cudaStream_t);
template void launch_bias_gelu<__half>(const __half*,
                                       const __half*,
                                       __half*,
                                       int,
                                       int,
                                       cudaStream_t);

template void launch_gelu<float>(const float*, float*, int, int, cudaStream_t);
template void launch_gelu<__half>(const __half*, __half*, int, int, cudaStream_t);

template <typename T>
void launch_d_gelu(T* d_output,
                   const T* input,
                   const T* bias,
                   int intermediate_size,
                   int batch_size,
                   cudaStream_t stream)
{
    int iterations = (intermediate_size + 1023) / 1024;
    int threads = intermediate_size / iterations / 4;
    dim3 block_dims(threads);
    dim3 grid_dims(batch_size);

    d_gelu_func<<<grid_dims, block_dims, 0, stream>>>(d_output, input, bias, intermediate_size);
}

template void launch_d_gelu<float>(float*, const float*, const float*, int, int, cudaStream_t);
template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, cudaStream_t);
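The launchers above are host-side entry points. Below is a minimal, illustrative sketch of how launch_bias_gelu could be exercised on its own; it assumes the launcher declarations are visible through "custom_cuda_layers.h" (as the include at the top of this file suggests), that intermediate_size is a multiple of 4 (required by the float4 path), and it runs on the default stream. It is not part of the commit.

    // Hypothetical standalone driver for launch_bias_gelu<float> (sketch only).
    #include <cuda_runtime.h>
    #include <cstdio>
    #include <vector>
    #include "custom_cuda_layers.h"  // assumed to declare launch_bias_gelu<T>

    int main()
    {
        const int rows = 2, intermediate = 1024;  // one block per row in the kernel
        std::vector<float> h_in(rows * intermediate, 1.0f);
        std::vector<float> h_bias(intermediate, 0.5f);
        std::vector<float> h_out(rows * intermediate, 0.0f);

        float *d_in, *d_bias, *d_out;
        cudaMalloc(&d_in, h_in.size() * sizeof(float));
        cudaMalloc(&d_bias, h_bias.size() * sizeof(float));
        cudaMalloc(&d_out, h_out.size() * sizeof(float));
        cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(float), cudaMemcpyHostToDevice);
        cudaMemcpy(d_bias, h_bias.data(), h_bias.size() * sizeof(float), cudaMemcpyHostToDevice);

        // grid = rows, block size derived from the intermediate size (see the heuristic above).
        launch_bias_gelu<float>(d_in, d_bias, d_out, intermediate, rows, 0);
        cudaDeviceSynchronize();

        cudaMemcpy(h_out.data(), d_out, h_out.size() * sizeof(float), cudaMemcpyDeviceToHost);
        // GELU(1.0 + 0.5) is roughly 1.40 under the tanh approximation.
        printf("out[0] = %f\n", h_out[0]);

        cudaFree(d_in);
        cudaFree(d_bias);
        cudaFree(d_out);
        return 0;
    }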
deepspeed/ops/csrc/transformer/general_kernels.cu
0 → 100644
#include "general_kernels.h"
namespace
cg
=
cooperative_groups
;
template
<
typename
T
>
__global__
void
column_sum_reduce
(
const
T
*
__restrict__
inp
,
T
*
__restrict__
out
,
int
rows
,
int
width
)
{
__shared__
float
tile
[
TILE_DIM
][
TILE_DIM
+
1
];
cg
::
thread_block
b
=
cg
::
this_thread_block
();
cg
::
thread_block_tile
<
TILE_DIM
>
g
=
cg
::
tiled_partition
<
TILE_DIM
>
(
b
);
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int
y_stride
=
width
*
TILE_DIM
;
float
localSum
=
0
;
// Loop across matrix height
if
(
idx
<
width
)
{
int
offset
=
threadIdx
.
y
*
width
+
idx
;
for
(
int
r
=
threadIdx
.
y
;
r
<
rows
;
r
+=
TILE_DIM
)
{
localSum
+=
(
float
)
inp
[
offset
];
offset
+=
y_stride
;
}
}
tile
[
threadIdx
.
x
][
threadIdx
.
y
]
=
localSum
;
__syncthreads
();
// Sum the shared buffer.
float
sum
=
tile
[
threadIdx
.
y
][
threadIdx
.
x
];
#ifndef __STOCHASTIC_MODE__
__syncthreads
();
#endif
for
(
int
i
=
1
;
i
<
TILE_DIM
;
i
<<=
1
)
sum
+=
g
.
shfl_down
(
sum
,
i
);
if
(
threadIdx
.
x
==
0
)
{
int
pos
=
blockIdx
.
x
*
TILE_DIM
+
threadIdx
.
y
;
if
(
pos
<
width
)
out
[
pos
]
=
sum
;
}
}
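// How the reduction above works (comment added for clarity; it describes the
// code as written). Each TILE_DIM x TILE_DIM block walks down the matrix in
// strides of TILE_DIM rows, so tile[x][y] holds the partial column sum computed
// by thread (x, y). Reading the tile back transposed, as
// tile[threadIdx.y][threadIdx.x], places all partials belonging to one output
// column into a single TILE_DIM-wide tile group, which then collapses them with
// shfl_down in log2(TILE_DIM) steps; lane 0 of each group writes the final
// per-column sum (used as the bias gradient by the transpose-bias launcher
// below). The +1 padding on the shared tile avoids bank conflicts on the
// transposed read.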
template <typename T>
void launch_fuse_transpose_bias_kernel(const T* inp,
                                       T* out,
                                       int rows,
                                       int cols,
                                       cudaStream_t stream);

template <>
void launch_fuse_transpose_bias_kernel<float>(const float* inp,
                                              float* out,
                                              int rows,
                                              int cols,
                                              cudaStream_t stream)
{
    // assert(rows % TILE_DIM == 0);
    // assert(cols % TILE_DIM == 0);

    dim3 grid_dim((cols - 1) / TILE_DIM + 1);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    column_sum_reduce<float><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
}

template <>
void launch_fuse_transpose_bias_kernel<__half>(const __half* inp,
                                               __half* out,
                                               int rows,
                                               int cols,
                                               cudaStream_t stream)
{
    // assert(rows % TILE_DIM == 0);
    // assert(cols % TILE_DIM == 0);

    dim3 grid_dim((cols - 1) / TILE_DIM + 1);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    column_sum_reduce<__half><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
}

__global__ void fused_add2_kernel(const int N, float* out, const float* inp1, const float* inp2)
{
    const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
    const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
    float4* out_4 = reinterpret_cast<float4*>(out);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        float4 val;
        float4 inp1_reg = inp1_4[j];
        float4 inp2_reg = inp2_4[j];

        val.x = inp1_reg.x + inp2_reg.x;
        val.y = inp1_reg.y + inp2_reg.y;
        val.z = inp1_reg.z + inp2_reg.z;
        val.w = inp1_reg.w + inp2_reg.w;

        out_4[j] = val;
    }
}

__global__ void fused_add2_kernel(const int N, __half* out, const __half* inp1, const __half* inp2)
{
    float2 inp1_4;
    float2 inp2_4;

    __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
    __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);

    const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
    const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        inp1_4 = inp1_arr[j];
        inp2_4 = inp2_arr[j];

        float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
        float2 inp1_h_f_1 = __half22float2(inp1_h[1]);

        float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
        float2 inp2_h_f_1 = __half22float2(inp2_h[1]);

        inp1_h_f_0.x += inp2_h_f_0.x;
        inp1_h_f_0.y += inp2_h_f_0.y;
        inp1_h_f_1.x += inp2_h_f_1.x;
        inp1_h_f_1.y += inp2_h_f_1.y;

        float2 val_f;
        __half2* val_h = reinterpret_cast<__half2*>(&val_f);
        val_h[0] = __float22half2_rn(inp1_h_f_0);
        val_h[1] = __float22half2_rn(inp1_h_f_1);

        float2* out_4 = reinterpret_cast<float2*>(out);
        out_4[j] = val_f;
    }
}

template <>
void launch_fused_add2<float>(float* out,
                              const float* inp1,
                              const float* inp2,
                              int batch_size,
                              int seq_length,
                              int hidden_dim,
                              cudaStream_t& stream)
{
    int total_count = batch_size * seq_length * hidden_dim / 4;
    dim3 grid_dim = DS_GET_BLOCKS(total_count);  //(batch_size * seq_length);
    dim3 block_dim = DS_CUDA_NUM_THREADS;        //(hidden_dim / 4);

    fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(total_count, out, inp1, inp2);
}

template <>
void launch_fused_add2<__half>(__half* out,
                               const __half* inp1,
                               const __half* inp2,
                               int batch_size,
                               int seq_length,
                               int hidden_dim,
                               cudaStream_t& stream)
{
    int total_count = batch_size * seq_length * hidden_dim / 4;
    dim3 grid_dim = DS_GET_BLOCKS(total_count);  //(batch_size * seq_length);
    dim3 block_dim = DS_CUDA_NUM_THREADS;        //(hidden_dim / 4);

    fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(total_count, out, inp1, inp2);
}

__global__ void fused_add3_kernel(float* out,
                                  const float* inp1,
                                  const float* inp2,
                                  const float* inp3,
                                  int size,
                                  int row_stride)
{
    int row = blockIdx.x;
    int id = threadIdx.x;

    const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
    const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
    const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);

    float4* out_4 = reinterpret_cast<float4*>(out);

    float4 val;
    float4 inp1_reg = inp1_4[row * row_stride + id];
    float4 inp2_reg = inp2_4[row * row_stride + id];
    float4 inp3_reg = inp3_4[row * row_stride + id];

    val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x;
    val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y;
    val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z;
    val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w;

    out_4[row * row_stride + id] = val;
}

__global__ void fused_add3_kernel(__half* out,
                                  const __half* inp1,
                                  const __half* inp2,
                                  const __half* inp3,
                                  int size,
                                  int row_stride)
{
    int row = blockIdx.x;
    int id = threadIdx.x;
    const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
    const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
    const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);

    float2 inp1_4 = inp1_arr[row * row_stride + id];
    float2 inp2_4 = inp2_arr[row * row_stride + id];
    float2 inp3_4 = inp3_arr[row * row_stride + id];

    __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
    __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
    __half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);

    float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
    float2 inp1_h_f_1 = __half22float2(inp1_h[1]);

    float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
    float2 inp2_h_f_1 = __half22float2(inp2_h[1]);

    float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
    float2 inp3_h_f_1 = __half22float2(inp3_h[1]);

    inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x);
    inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y);
    inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x);
    inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y);

    float2 val_f;
    __half2* val_h = reinterpret_cast<__half2*>(&val_f);
    val_h[0] = __float22half2_rn(inp1_h_f_0);
    val_h[1] = __float22half2_rn(inp1_h_f_1);

    float2* out_4 = reinterpret_cast<float2*>(out);
    out_4[row * row_stride + id] = val_f;
}

template <>
void launch_fused_add3<float>(float* out,
                              const float* inp1,
                              const float* inp2,
                              const float* inp3,
                              int batch_size,
                              int seq_length,
                              int hidden_size,
                              cudaStream_t& stream)
{
    dim3 grid_dim(batch_size * seq_length);

    dim3 block_dim(hidden_size / 4);

    fused_add3_kernel<<<grid_dim, block_dim, 0, stream>>>(
        out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}

template <>
void launch_fused_add3<__half>(__half* out,
                               const __half* inp1,
                               const __half* inp2,
                               const __half* inp3,
                               int batch_size,
                               int seq_length,
                               int hidden_size,
                               cudaStream_t& stream)
{
    dim3 grid_dim(batch_size * seq_length);

    dim3 block_dim(hidden_size / 4);

    fused_add3_kernel<<<grid_dim, block_dim, 0, stream>>>(
        out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}

__global__ void fused_add4_kernel(float* out,
                                  const float* inp1,
                                  const float* inp2,
                                  const float* inp3,
                                  const float* inp4,
                                  int size,
                                  int row_stride)
{
    int row = blockIdx.x;
    int id = threadIdx.x;
    const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
    const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
    const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
    const float4* inp4_4 = reinterpret_cast<const float4*>(inp4);
    float4* out_4 = reinterpret_cast<float4*>(out);

    float4 val;
    float4 inp1_reg = inp1_4[row * row_stride + id];
    float4 inp2_reg = inp2_4[row * row_stride + id];
    float4 inp3_reg = inp3_4[row * row_stride + id];
    float4 inp4_reg = inp4_4[row * row_stride + id];

    val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x + inp4_reg.x;
    val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y + inp4_reg.y;
    val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z + inp4_reg.z;
    val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w + inp4_reg.w;

    out_4[row * row_stride + id] = val;
}

__global__ void fused_add4_kernel(__half* out,
                                  const __half* inp1,
                                  const __half* inp2,
                                  const __half* inp3,
                                  const __half* inp4,
                                  int size,
                                  int row_stride)
{
    int row = blockIdx.x;
    int id = threadIdx.x;
    const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
    const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
    const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
    const float2* inp4_arr = reinterpret_cast<const float2*>(inp4);

    float2 inp1_4 = inp1_arr[row * row_stride + id];
    float2 inp2_4 = inp2_arr[row * row_stride + id];
    float2 inp3_4 = inp3_arr[row * row_stride + id];
    float2 inp4_4 = inp4_arr[row * row_stride + id];

    __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
    __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
    __half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
    __half2* inp4_h = reinterpret_cast<__half2*>(&inp4_4);

    float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
    float2 inp1_h_f_1 = __half22float2(inp1_h[1]);

    float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
    float2 inp2_h_f_1 = __half22float2(inp2_h[1]);

    float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
    float2 inp3_h_f_1 = __half22float2(inp3_h[1]);

    float2 inp4_h_f_0 = __half22float2(inp4_h[0]);
    float2 inp4_h_f_1 = __half22float2(inp4_h[1]);

    inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x + inp4_h_f_0.x);
    inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y + inp4_h_f_0.y);
    inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x + inp4_h_f_1.x);
    inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y + inp4_h_f_1.y);

    float2 val_f;
    __half2* val_h = reinterpret_cast<__half2*>(&val_f);
    val_h[0] = __float22half2_rn(inp1_h_f_0);
    val_h[1] = __float22half2_rn(inp1_h_f_1);

    float2* out_4 = reinterpret_cast<float2*>(out);
    out_4[row * row_stride + id] = val_f;
}

template <>
void launch_fused_add4<float>(float* out,
                              const float* inp1,
                              const float* inp2,
                              const float* inp3,
                              const float* inp4,
                              int batch_size,
                              int seq_length,
                              int hidden_size,
                              cudaStream_t& stream)
{
    dim3 grid_dim(batch_size * seq_length);

    dim3 block_dim(hidden_size / 4);

    fused_add4_kernel<<<grid_dim, block_dim, 0, stream>>>(
        out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}

template <>
void launch_fused_add4<__half>(__half* out,
                               const __half* inp1,
                               const __half* inp2,
                               const __half* inp3,
                               const __half* inp4,
                               int batch_size,
                               int seq_length,
                               int hidden_size,
                               cudaStream_t& stream)
{
    dim3 grid_dim(batch_size * seq_length);

    dim3 block_dim(hidden_size / 4);

    fused_add4_kernel<<<grid_dim, block_dim, 0, stream>>>(
        out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
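One practical difference among the launchers above: launch_fused_add2 uses a grid-stride configuration (DS_GET_BLOCKS / DS_CUDA_NUM_THREADS) over the flattened element count, while launch_fused_add3 and launch_fused_add4 put one thread block per token and hidden_size / 4 threads in it. Since CUDA caps a block at 1024 threads, the add3/add4 path implicitly assumes hidden_size is at most 4096 and divisible by 4. A hypothetical guard a caller might add, shown only as a sketch and not part of this commit:

    // Hypothetical precondition check before calling launch_fused_add3 / launch_fused_add4.
    #include <cassert>

    inline void check_fused_add_block_dims(int hidden_size)
    {
        assert(hidden_size % 4 == 0);     // float4 / paired-__half2 vectorized loads
        assert(hidden_size / 4 <= 1024);  // CUDA limit: at most 1024 threads per block
    }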