OpenDAS / deepspeed
Commit eadbbe09, authored Apr 25, 2021 by 401qingkong
push rocm deepspeed v0.3.13
parent ab5534fc
Changes: 155 files in this commit. Showing 20 changed files with 5403 additions and 0 deletions (+5403, -0).
deepspeed/ops/csrc/includes/hip/general_kernels.h        +47    -0
deepspeed/ops/csrc/includes/hip/normalize_layer.h        +202   -0
deepspeed/ops/csrc/includes/hip/softmax.h                +60    -0
deepspeed/ops/csrc/includes/hip/strided_batch_gemm.h     +179   -0
deepspeed/ops/csrc/includes/hip/type_shim.h              +110   -0
deepspeed/ops/csrc/includes/normalize_layer.h            +202   -0
deepspeed/ops/csrc/includes/softmax.h                    +60    -0
deepspeed/ops/csrc/includes/strided_batch_gemm.h         +179   -0
deepspeed/ops/csrc/includes/type_shim.h                  +109   -0
deepspeed/ops/csrc/lamb/fused_lamb_cuda.cpp              +109   -0
deepspeed/ops/csrc/lamb/fused_lamb_cuda_kernel.cu        +469   -0
deepspeed/ops/csrc/lamb/hip/fused_lamb_hip.cpp           +109   -0
deepspeed/ops/csrc/lamb/hip/fused_lamb_hip_kernel.hip    +469   -0
deepspeed/ops/csrc/sparse_attention/hip/utils.cpp        +120   -0
deepspeed/ops/csrc/sparse_attention/utils.cpp            +120   -0
deepspeed/ops/csrc/transformer/cublas_wrappers.cu        +199   -0
deepspeed/ops/csrc/transformer/dropout_kernels.cu        +868   -0
deepspeed/ops/csrc/transformer/ds_transformer_cuda.cpp   +1046  -0
deepspeed/ops/csrc/transformer/gelu_kernels.cu           +335   -0
deepspeed/ops/csrc/transformer/general_kernels.cu        +411   -0
deepspeed/ops/csrc/includes/hip/general_kernels.h (new file, mode 100644)

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>

#include <stdio.h>
#include <stdlib.h>

#include <cooperative_groups.h>
#include <hiprand_kernel.h>

#include "context.h"
#include "cublas_wrappers.h"

#define THREADS 256
#define TILE_DIM 32

#define minus_infinity -1 * std::numeric_limits<float>::infinity()

#define FINAL_MASK 0xffffffff

template <typename T>
void launch_fused_add2(T* out,
                       const T* inp1,
                       const T* inp2,
                       int batch_size,
                       int seq_length,
                       int hidden_size,
                       hipStream_t& stream);

template <typename T>
void launch_fused_add4(T* out,
                       const T* inp1,
                       const T* inp2,
                       const T* inp3,
                       const T* inp4,
                       int batch_size,
                       int seq_length,
                       int hidden_size,
                       hipStream_t& stream);

template <typename T>
void launch_fused_add3(T* out,
                       const T* inp1,
                       const T* inp2,
                       const T* inp3,
                       int batch_size,
                       int seq_length,
                       int hidden_size,
                       hipStream_t& stream);
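The header above only declares the fused-add launchers; their definitions live in the transformer kernel sources later in this commit. A minimal sketch of how launch_fused_add2 might be driven, assuming those kernels are compiled and linked and the header is on the include path (shapes and pointer names here are illustrative, not taken from the diff):

// Sketch only: assumes the DeepSpeed transformer kernels defining
// launch_fused_add2 are built and linked; shapes are hypothetical.
#include <hip/hip_runtime.h>
#include "general_kernels.h"

void add_two_residuals(float* out, const float* inp1, const float* inp2, hipStream_t stream)
{
    const int batch_size = 8;      // hypothetical [batch, seq, hidden] buffer
    const int seq_length = 128;
    const int hidden_size = 1024;
    // Element-wise out = inp1 + inp2 over batch * seq * hidden elements.
    launch_fused_add2<float>(out, inp1, inp2, batch_size, seq_length, hidden_size, stream);
}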
deepspeed/ops/csrc/includes/hip/normalize_layer.h (new file, mode 100644)

#pragma once

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include <fstream>
#include "custom_cuda_layers.h"

using namespace std;

template <typename T>
class Normalize_Layer {
public:
    struct Config {
        uint32_t batchSize;
        uint32_t seqLength;
        uint32_t hiddenDim;
        float epsilon;
        bool training;
        bool useMean;
        Config(uint32_t batch,
               uint32_t seq,
               uint32_t h,
               float epsilon = 1e-12,
               bool training = true,
               bool useMean = true)
            : batchSize(batch),
              seqLength(seq),
              hiddenDim(h),
              epsilon(epsilon),
              training(training),
              useMean(useMean)
        {
        }
    };

    Normalize_Layer(Config config)
        : config_(config), vars(nullptr), means(nullptr), vals_hat(nullptr)
    {
    }

    ~Normalize_Layer() {}

    void ForwardCheckpoint(int bsz,  // batch * seq
                           T* vals,
                           const T* residual,
                           const T* gamma,
                           const T* betta,
                           hipStream_t& stream,
                           bool preLayerNorm = false)
    {
        launch_bias_residual_layer_norm(vals,
                                        residual,
                                        gamma,
                                        betta,
                                        config_.epsilon,
                                        bsz,
                                        config_.hiddenDim,
                                        stream,
                                        preLayerNorm,
                                        config_.training,
                                        vars,
                                        means);
    }

    void Forward(int bsz,
                 T* vals,
                 const T* residual,
                 const T* gamma,
                 const T* betta,
                 hipStream_t& stream,
                 bool preLayerNorm = false)
    {
        launch_bias_residual_layer_norm(vals,
                                        residual,
                                        gamma,
                                        betta,
                                        config_.epsilon,
                                        bsz,
                                        config_.hiddenDim,
                                        stream,
                                        preLayerNorm,
                                        config_.training,
                                        vars);
    }

    void Backward(int bsz,
                  const T* out_grad,
                  const T* gamma,
                  T* gamma_grad,
                  T* betta_grad,
                  hipStream_t stream[2],
                  T* inp_grad_out,
                  const T* norm_in = nullptr)
    {
        launch_layerNorm_backward(out_grad,
                                  norm_in,
                                  vars,
                                  means,
                                  gamma,
                                  gamma_grad,
                                  betta_grad,
                                  inp_grad_out,
                                  bsz,
                                  config_.hiddenDim,
                                  stream);
    }

    void Backward(int bsz,
                  const T* out_grad,
                  const T* gamma,
                  const T* betta,
                  T* gamma_grad,
                  T* betta_grad,
                  hipStream_t stream[2],
                  T* inp_grad_out,
                  const T* norm_out)
    {
        launch_layerNorm_backward(out_grad,
                                  norm_out,
                                  vars,
                                  gamma,
                                  gamma_grad,
                                  betta_grad,
                                  inp_grad_out,
                                  bsz,
                                  config_.hiddenDim,
                                  stream,
                                  !config_.useMean,
                                  betta);
    }

    void BackwardFusedAdd(int bsz,
                          const T* out_grad1,
                          const T* out_grad2,
                          const T* gamma,
                          T* gamma_grad,
                          T* betta_grad,
                          hipStream_t stream[2],
                          T* inp_grad_out,
                          const T* norm_in = nullptr)
    {
        launch_layerNorm_backward_fused_add(out_grad1,
                                            out_grad2,
                                            norm_in,
                                            vars,
                                            means,
                                            gamma,
                                            gamma_grad,
                                            betta_grad,
                                            inp_grad_out,
                                            bsz,
                                            config_.hiddenDim,
                                            stream);
    }

    void BackwardFusedAdd(int bsz,
                          const T* out_grad1,
                          const T* out_grad2,
                          const T* gamma,
                          const T* betta,
                          T* gamma_grad,
                          T* betta_grad,
                          hipStream_t stream[2],
                          T* inp_grad_out,
                          const T* norm_out)
    {
        launch_layerNorm_backward_fused_add(out_grad1,
                                            out_grad2,
                                            norm_out,
                                            vars,
                                            gamma,
                                            gamma_grad,
                                            betta_grad,
                                            inp_grad_out,
                                            bsz,
                                            config_.hiddenDim,
                                            stream,
                                            !config_.useMean,
                                            betta);
    }

    inline bool UseMean() const { return config_.useMean; }

    inline void SetVar(T* variance)
    {
        if (!variance) { throw std::runtime_error("Normalize variance is null."); }
        vars = variance;
    }

    inline void SetMean(T* mean)
    {
        if (!mean) { throw std::runtime_error("Normalize mean is null."); }
        means = mean;
    }

private:
    Config config_;
    T* vars;
    T* means;
    T* vals_hat;
};
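For orientation, a minimal sketch of how this class might be wired up, assuming the launch_* kernels declared in custom_cuda_layers.h are linked in and all device buffers are allocated elsewhere (shapes and buffer names are hypothetical, not part of this commit):

// Sketch only: device buffers and kernel implementations are assumed to exist.
#include <hip/hip_runtime.h>
#include "normalize_layer.h"

void run_layer_norm(float* vals, const float* residual,
                    const float* gamma, const float* betta,
                    float* var_buf, float* mean_buf, hipStream_t stream)
{
    const uint32_t batch = 8, seq = 128, hidden = 1024;  // hypothetical shapes
    Normalize_Layer<float> norm(Normalize_Layer<float>::Config(batch, seq, hidden));

    // Per-row statistics are written into externally owned buffers.
    norm.SetVar(var_buf);
    norm.SetMean(mean_buf);

    // bsz is batch * seq, matching the comment on ForwardCheckpoint above.
    norm.ForwardCheckpoint(batch * seq, vals, residual, gamma, betta, stream);
}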
deepspeed/ops/csrc/includes/hip/softmax.h (new file, mode 100644)

#pragma once

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include "custom_cuda_layers.h"

#include <fstream>

using namespace std;

template <typename T>
class Softmax {
public:
    struct Config {
        size_t batchSize;
        size_t heads;
        size_t seq_length;
        size_t prob_depth;
        float temprature;
        bool mem_alloc;
        Config(size_t batch, size_t h, size_t seq, int prob_size = 0, bool mem_alloc = false)
            : batchSize(batch),
              heads(h),
              seq_length(seq),
              prob_depth(prob_size),
              temprature(1.0),
              mem_alloc(mem_alloc)
        {
        }
    };

    Softmax(Config config) : config_(config) {}

    ~Softmax() {}

    void Forward(int bsz, T* vals, const T* attn_mask, hipStream_t& stream)
    {
        launch_attn_softmax<T>(vals, attn_mask, bsz, config_.heads, config_.seq_length, stream);
    }

    void Backward(int bsz, T* out_grad, const T* soft_out, hipStream_t stream)
    {
        launch_attn_softmax_backward_v2<T>(
            out_grad, soft_out, bsz, config_.heads, config_.seq_length, stream);
    }

    inline size_t GetProbDepth() const { return config_.prob_depth; }

    inline size_t GetBatchSize() const { return config_.batchSize; }

    inline size_t GetNumHeads() const { return config_.heads; }

    inline size_t GetSeqLength() const { return config_.seq_length; }

    inline void SetSeqLength(size_t seq_len) { config_.seq_length = seq_len; }

private:
    Config config_;
};
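A minimal usage sketch, assuming attention scores of shape [batch, heads, seq, seq] and the additive mask are already resident on the device and launch_attn_softmax is linked in (all names and shapes are illustrative):

// Sketch only: buffers and kernel implementations are assumed to exist.
#include <hip/hip_runtime.h>
#include "softmax.h"

void run_attn_softmax(float* attn_scores, const float* attn_mask, hipStream_t stream)
{
    const size_t batch = 8, heads = 16, seq = 128;  // hypothetical shapes
    Softmax<float> softmax(Softmax<float>::Config(batch, heads, seq));
    // Softmax is applied in place over the last dimension of attn_scores.
    softmax.Forward(static_cast<int>(batch), attn_scores, attn_mask, stream);
}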
deepspeed/ops/csrc/includes/hip/strided_batch_gemm.h (new file, mode 100644)

#pragma once

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include "context.h"

template <typename T>
class StridedBatchGemm {
public:
    struct Config {
        int batch_size;
        int m;
        int n;
        int k;
        float alpha;
        float beta;
        rocblas_operation op_A;
        rocblas_operation op_B;
        std::array<int, 3> gemm_algos;
        Config(int batch,
               int mm,
               int nn,
               int kk,
               float param_alpha,
               float param_beta,
               rocblas_operation opA,
               rocblas_operation opB,
               const std::array<int, 3>& algos)
            : batch_size(batch),
              m(mm),
              n(nn),
              k(kk),
              alpha(param_alpha),
              beta(param_beta),
              op_A(opA),
              op_B(opB),
              gemm_algos(algos)
        {
        }
        void SetConfig(int mm, int nn, int kk)
        {
            m = mm;
            n = nn;
            k = kk;
        }
    };

    StridedBatchGemm(const Config& config) : _config(config) {}

    virtual ~StridedBatchGemm() {}

    void Forward(int bsz, T* output, const T* _buffer_a, const T* _buffer_b, rocblas_handle handle)
    {
        int stride_a = _config.m * _config.k;
        int stride_b = _config.n * _config.k;
        int stride_c = _config.m * _config.n;

        cublas_strided_batched_gemm(handle,
                                    _config.m,
                                    _config.n,
                                    _config.k,
                                    &_config.alpha,
                                    &_config.beta,
                                    _buffer_a,
                                    _buffer_b,
                                    output,
                                    _config.op_A,
                                    _config.op_B,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    bsz,
                                    cublasGemmAlgo_t(_config.gemm_algos[0]));
    }

    void ForwardPlusSave(T* output, const T* _buffer_a, const T* _buffer_b, rocblas_handle handle)
    {
        int stride_a = _config.m * _config.k;
        int stride_b = _config.n * _config.k;
        int stride_c = _config.m * _config.n;

        cublas_strided_batched_gemm(handle,
                                    _config.m,
                                    _config.n,
                                    _config.k,
                                    &_config.alpha,
                                    &_config.beta,
                                    _buffer_a,
                                    _buffer_b,
                                    output,
                                    _config.op_A,
                                    _config.op_B,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    _config.batch_size,
                                    cublasGemmAlgo_t(_config.gemm_algos[0]));

        k_buf = _buffer_a;
        q_buf = _buffer_b;
    }

    void Backward(int bsz,
                  const T* d_output,
                  const T* _buffer_a,
                  const T* _buffer_b,
                  rocblas_handle handle,
                  T* inpGradA = nullptr,
                  T* inpGradB = nullptr)
    {
        int mb = (_config.op_A == rocblas_operation_transpose ? _config.k : _config.m);
        int kb = (_config.op_A == rocblas_operation_transpose ? _config.m : _config.k);

        int stride_a = mb * _config.n;
        int stride_b = _config.n * kb;
        int stride_c = _config.m * _config.k;

        // B need to transpose.
        rocblas_operation op_b = (_config.op_B == rocblas_operation_transpose
                                      ? rocblas_operation_none
                                      : rocblas_operation_transpose);

        // Calculate d_A.
        cublas_strided_batched_gemm(handle,
                                    mb,
                                    kb,
                                    _config.n,
                                    &_config.alpha,
                                    &_config.beta,
                                    (_config.op_A == rocblas_operation_transpose ? _buffer_b : d_output),
                                    (_config.op_A == rocblas_operation_transpose ? d_output : _buffer_b),
                                    inpGradA,
                                    rocblas_operation_none,
                                    op_b,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    bsz,
                                    cublasGemmAlgo_t(_config.gemm_algos[1]));

        // A need to transpose.
        rocblas_operation op_a = (_config.op_A == rocblas_operation_transpose
                                      ? rocblas_operation_none
                                      : rocblas_operation_transpose);

        stride_a = _config.m * _config.k;
        stride_b = _config.m * _config.n;
        stride_c = _config.n * _config.k;

        // Calculate d_B.
        cublas_strided_batched_gemm(handle,
                                    _config.k,
                                    _config.n,
                                    _config.m,
                                    &_config.alpha,
                                    &_config.beta,
                                    _buffer_a,
                                    d_output,
                                    inpGradB,
                                    op_a,
                                    rocblas_operation_none,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    bsz,
                                    cublasGemmAlgo_t(_config.gemm_algos[2]));
    }

    inline int GetN() const { return _config.k; }

    inline const T* GetBufferA() const { return k_buf; }

    inline const T* GetBufferB() const { return q_buf; }

    inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }

private:
    Config _config;
    const T* q_buf;
    const T* k_buf;
};
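Forward runs a batched C = alpha * op_A(A) * op_B(B) with fixed per-batch strides m*k, n*k, and m*n, so the Config fully determines the GEMM. As a hedged illustration only, a plausible config for an attention-score GEMM (one K^T x Q product per head); the rocBLAS operation enums are assumed to be visible through context.h, and the algorithm indices are placeholders, not values taken from this commit:

// Sketch only: the 1/sqrt(d) scaling via alpha and the algo indices are assumptions.
#include <array>
#include <cmath>
#include "strided_batch_gemm.h"

StridedBatchGemm<float>::Config make_score_gemm_config(int batch_heads, int seq, int head_dim)
{
    const float alpha = 1.0f / std::sqrt(static_cast<float>(head_dim));
    const float beta = 0.0f;
    std::array<int, 3> algos = {99, 99, 99};  // placeholder GEMM algorithm indices
    // m = n = seq, k = head_dim; op_A transposes the K buffer.
    return StridedBatchGemm<float>::Config(batch_heads, seq, seq, head_dim,
                                           alpha, beta,
                                           rocblas_operation_transpose,
                                           rocblas_operation_none,
                                           algos);
}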
deepspeed/ops/csrc/includes/hip/type_shim.h (new file, mode 100644)

#include "hip/hip_runtime.h"
/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */
#include <ATen/ATen.h>
// Forward/backward compatiblity hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(T* x,
                                                     T val,
                                                     int lanes = 1,
                                                     bool share_result = false)
// lanes is intended to be <= 32.
{
    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    int blockSize = blockDim.x * blockDim.y;
    // blockSize is intended to be a multiple of 32.

    if (blockSize >= 64) {
        x[tid] = val;
        __syncthreads();
    }

#pragma unroll
    for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
        if (tid < i) x[tid] = x[tid] + x[tid + i];
        __syncthreads();
    }

    T final;

    if (tid < 32) {
        if (blockSize >= 64)
            final = x[tid] + x[tid + 32];
        else
            final = val;
        // __SYNCWARP();

#pragma unroll
        for (int i = 16; i >= lanes; i >>= 1)
            final = final + __shfl_down_sync(0xffffffff, final, i);
    }

    if (share_result) {
        if (tid < lanes) x[tid] = final;  // EpilogueOp
        // Make sure the smem result is visible to all warps.
        __syncthreads();
    }

    return final;
}
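The DISPATCH_* macros above expand to a switch over the runtime scalar type and bind the selected type to the alias scalar_t_<LEVEL> inside the body. A usage sketch, where norm_kernel is a hypothetical templated function introduced only for this example (it is not part of this commit):

// Sketch only: norm_kernel is a stand-in for a real templated kernel wrapper.
#include <ATen/ATen.h>
#include "type_shim.h"

template <typename scalar_t>
void norm_kernel(const scalar_t* data, int64_t n)
{
    // Hypothetical body; a real wrapper would launch a device kernel here.
    (void)data;
    (void)n;
}

void dispatch_norm(const at::Tensor& t)
{
    // LEVEL = 0, so the macro defines scalar_t_0 as float or at::Half.
    DISPATCH_FLOAT_AND_HALF(t.scalar_type(), 0, "norm_kernel",
                            norm_kernel<scalar_t_0>(t.data_ptr<scalar_t_0>(), t.numel()));
}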
deepspeed/ops/csrc/includes/normalize_layer.h (new file, mode 100644)

#pragma once

#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <fstream>
#include "custom_cuda_layers.h"

using namespace std;

template <typename T>
class Normalize_Layer {
public:
    struct Config {
        uint32_t batchSize;
        uint32_t seqLength;
        uint32_t hiddenDim;
        float epsilon;
        bool training;
        bool useMean;
        Config(uint32_t batch,
               uint32_t seq,
               uint32_t h,
               float epsilon = 1e-12,
               bool training = true,
               bool useMean = true)
            : batchSize(batch),
              seqLength(seq),
              hiddenDim(h),
              epsilon(epsilon),
              training(training),
              useMean(useMean)
        {
        }
    };

    Normalize_Layer(Config config)
        : config_(config), vars(nullptr), means(nullptr), vals_hat(nullptr)
    {
    }

    ~Normalize_Layer() {}

    void ForwardCheckpoint(int bsz,  // batch * seq
                           T* vals,
                           const T* residual,
                           const T* gamma,
                           const T* betta,
                           cudaStream_t& stream,
                           bool preLayerNorm = false)
    {
        launch_bias_residual_layer_norm(vals,
                                        residual,
                                        gamma,
                                        betta,
                                        config_.epsilon,
                                        bsz,
                                        config_.hiddenDim,
                                        stream,
                                        preLayerNorm,
                                        config_.training,
                                        vars,
                                        means);
    }

    void Forward(int bsz,
                 T* vals,
                 const T* residual,
                 const T* gamma,
                 const T* betta,
                 cudaStream_t& stream,
                 bool preLayerNorm = false)
    {
        launch_bias_residual_layer_norm(vals,
                                        residual,
                                        gamma,
                                        betta,
                                        config_.epsilon,
                                        bsz,
                                        config_.hiddenDim,
                                        stream,
                                        preLayerNorm,
                                        config_.training,
                                        vars);
    }

    void Backward(int bsz,
                  const T* out_grad,
                  const T* gamma,
                  T* gamma_grad,
                  T* betta_grad,
                  cudaStream_t stream[2],
                  T* inp_grad_out,
                  const T* norm_in = nullptr)
    {
        launch_layerNorm_backward(out_grad,
                                  norm_in,
                                  vars,
                                  means,
                                  gamma,
                                  gamma_grad,
                                  betta_grad,
                                  inp_grad_out,
                                  bsz,
                                  config_.hiddenDim,
                                  stream);
    }

    void Backward(int bsz,
                  const T* out_grad,
                  const T* gamma,
                  const T* betta,
                  T* gamma_grad,
                  T* betta_grad,
                  cudaStream_t stream[2],
                  T* inp_grad_out,
                  const T* norm_out)
    {
        launch_layerNorm_backward(out_grad,
                                  norm_out,
                                  vars,
                                  gamma,
                                  gamma_grad,
                                  betta_grad,
                                  inp_grad_out,
                                  bsz,
                                  config_.hiddenDim,
                                  stream,
                                  !config_.useMean,
                                  betta);
    }

    void BackwardFusedAdd(int bsz,
                          const T* out_grad1,
                          const T* out_grad2,
                          const T* gamma,
                          T* gamma_grad,
                          T* betta_grad,
                          cudaStream_t stream[2],
                          T* inp_grad_out,
                          const T* norm_in = nullptr)
    {
        launch_layerNorm_backward_fused_add(out_grad1,
                                            out_grad2,
                                            norm_in,
                                            vars,
                                            means,
                                            gamma,
                                            gamma_grad,
                                            betta_grad,
                                            inp_grad_out,
                                            bsz,
                                            config_.hiddenDim,
                                            stream);
    }

    void BackwardFusedAdd(int bsz,
                          const T* out_grad1,
                          const T* out_grad2,
                          const T* gamma,
                          const T* betta,
                          T* gamma_grad,
                          T* betta_grad,
                          cudaStream_t stream[2],
                          T* inp_grad_out,
                          const T* norm_out)
    {
        launch_layerNorm_backward_fused_add(out_grad1,
                                            out_grad2,
                                            norm_out,
                                            vars,
                                            gamma,
                                            gamma_grad,
                                            betta_grad,
                                            inp_grad_out,
                                            bsz,
                                            config_.hiddenDim,
                                            stream,
                                            !config_.useMean,
                                            betta);
    }

    inline bool UseMean() const { return config_.useMean; }

    inline void SetVar(T* variance)
    {
        if (!variance) { throw std::runtime_error("Normalize variance is null."); }
        vars = variance;
    }

    inline void SetMean(T* mean)
    {
        if (!mean) { throw std::runtime_error("Normalize mean is null."); }
        means = mean;
    }

private:
    Config config_;
    T* vars;
    T* means;
    T* vals_hat;
};
deepspeed/ops/csrc/includes/softmax.h (new file, mode 100755)

#pragma once

#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include "custom_cuda_layers.h"

#include <fstream>

using namespace std;

template <typename T>
class Softmax {
public:
    struct Config {
        size_t batchSize;
        size_t heads;
        size_t seq_length;
        size_t prob_depth;
        float temprature;
        bool mem_alloc;
        Config(size_t batch, size_t h, size_t seq, int prob_size = 0, bool mem_alloc = false)
            : batchSize(batch),
              heads(h),
              seq_length(seq),
              prob_depth(prob_size),
              temprature(1.0),
              mem_alloc(mem_alloc)
        {
        }
    };

    Softmax(Config config) : config_(config) {}

    ~Softmax() {}

    void Forward(int bsz, T* vals, const T* attn_mask, cudaStream_t& stream)
    {
        launch_attn_softmax<T>(vals, attn_mask, bsz, config_.heads, config_.seq_length, stream);
    }

    void Backward(int bsz, T* out_grad, const T* soft_out, cudaStream_t stream)
    {
        launch_attn_softmax_backward_v2<T>(
            out_grad, soft_out, bsz, config_.heads, config_.seq_length, stream);
    }

    inline size_t GetProbDepth() const { return config_.prob_depth; }

    inline size_t GetBatchSize() const { return config_.batchSize; }

    inline size_t GetNumHeads() const { return config_.heads; }

    inline size_t GetSeqLength() const { return config_.seq_length; }

    inline void SetSeqLength(size_t seq_len) { config_.seq_length = seq_len; }

private:
    Config config_;
};
deepspeed/ops/csrc/includes/strided_batch_gemm.h (new file, mode 100644)

#pragma once

#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include "context.h"

template <typename T>
class StridedBatchGemm {
public:
    struct Config {
        int batch_size;
        int m;
        int n;
        int k;
        float alpha;
        float beta;
        cublasOperation_t op_A;
        cublasOperation_t op_B;
        std::array<int, 3> gemm_algos;
        Config(int batch,
               int mm,
               int nn,
               int kk,
               float param_alpha,
               float param_beta,
               cublasOperation_t opA,
               cublasOperation_t opB,
               const std::array<int, 3>& algos)
            : batch_size(batch),
              m(mm),
              n(nn),
              k(kk),
              alpha(param_alpha),
              beta(param_beta),
              op_A(opA),
              op_B(opB),
              gemm_algos(algos)
        {
        }
        void SetConfig(int mm, int nn, int kk)
        {
            m = mm;
            n = nn;
            k = kk;
        }
    };

    StridedBatchGemm(const Config& config) : _config(config) {}

    virtual ~StridedBatchGemm() {}

    void Forward(int bsz, T* output, const T* _buffer_a, const T* _buffer_b, cublasHandle_t handle)
    {
        int stride_a = _config.m * _config.k;
        int stride_b = _config.n * _config.k;
        int stride_c = _config.m * _config.n;

        cublas_strided_batched_gemm(handle,
                                    _config.m,
                                    _config.n,
                                    _config.k,
                                    &_config.alpha,
                                    &_config.beta,
                                    _buffer_a,
                                    _buffer_b,
                                    output,
                                    _config.op_A,
                                    _config.op_B,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    bsz,
                                    cublasGemmAlgo_t(_config.gemm_algos[0]));
    }

    void ForwardPlusSave(T* output, const T* _buffer_a, const T* _buffer_b, cublasHandle_t handle)
    {
        int stride_a = _config.m * _config.k;
        int stride_b = _config.n * _config.k;
        int stride_c = _config.m * _config.n;

        cublas_strided_batched_gemm(handle,
                                    _config.m,
                                    _config.n,
                                    _config.k,
                                    &_config.alpha,
                                    &_config.beta,
                                    _buffer_a,
                                    _buffer_b,
                                    output,
                                    _config.op_A,
                                    _config.op_B,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    _config.batch_size,
                                    cublasGemmAlgo_t(_config.gemm_algos[0]));

        k_buf = _buffer_a;
        q_buf = _buffer_b;
    }

    void Backward(int bsz,
                  const T* d_output,
                  const T* _buffer_a,
                  const T* _buffer_b,
                  cublasHandle_t handle,
                  T* inpGradA = nullptr,
                  T* inpGradB = nullptr)
    {
        int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m);
        int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k);

        int stride_a = mb * _config.n;
        int stride_b = _config.n * kb;
        int stride_c = _config.m * _config.k;

        // B need to transpose.
        cublasOperation_t op_b = (_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);

        // Calculate d_A.
        cublas_strided_batched_gemm(handle,
                                    mb,
                                    kb,
                                    _config.n,
                                    &_config.alpha,
                                    &_config.beta,
                                    (_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output),
                                    (_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b),
                                    inpGradA,
                                    CUBLAS_OP_N,
                                    op_b,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    bsz,
                                    cublasGemmAlgo_t(_config.gemm_algos[1]));

        // A need to transpose.
        cublasOperation_t op_a = (_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);

        stride_a = _config.m * _config.k;
        stride_b = _config.m * _config.n;
        stride_c = _config.n * _config.k;

        // Calculate d_B.
        cublas_strided_batched_gemm(handle,
                                    _config.k,
                                    _config.n,
                                    _config.m,
                                    &_config.alpha,
                                    &_config.beta,
                                    _buffer_a,
                                    d_output,
                                    inpGradB,
                                    op_a,
                                    CUBLAS_OP_N,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    bsz,
                                    cublasGemmAlgo_t(_config.gemm_algos[2]));
    }

    inline int GetN() const { return _config.k; }

    inline const T* GetBufferA() const { return k_buf; }

    inline const T* GetBufferB() const { return q_buf; }

    inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }

private:
    Config _config;
    const T* q_buf;
    const T* k_buf;
};
deepspeed/ops/csrc/includes/type_shim.h (new file, mode 100644)

/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */
#include <ATen/ATen.h>
// Forward/backward compatiblity hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(T* x,
                                                     T val,
                                                     int lanes = 1,
                                                     bool share_result = false)
// lanes is intended to be <= 32.
{
    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    int blockSize = blockDim.x * blockDim.y;
    // blockSize is intended to be a multiple of 32.

    if (blockSize >= 64) {
        x[tid] = val;
        __syncthreads();
    }

#pragma unroll
    for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
        if (tid < i) x[tid] = x[tid] + x[tid + i];
        __syncthreads();
    }

    T final;

    if (tid < 32) {
        if (blockSize >= 64)
            final = x[tid] + x[tid + 32];
        else
            final = val;
        // __SYNCWARP();

#pragma unroll
        for (int i = 16; i >= lanes; i >>= 1)
            final = final + __shfl_down_sync(0xffffffff, final, i);
    }

    if (share_result) {
        if (tid < lanes) x[tid] = final;  // EpilogueOp
        // Make sure the smem result is visible to all warps.
        __syncthreads();
    }

    return final;
}
deepspeed/ops/csrc/lamb/fused_lamb_cuda.cpp (new file, mode 100644)

/* Copyright 2019 The Microsoft DeepSpeed Team */
#include <torch/extension.h>

// CUDA forward declaration
void fused_lamb_cuda(at::Tensor& p,
                     at::Tensor& p_copy,
                     at::Tensor& m,
                     at::Tensor& v,
                     at::Tensor& g,
                     float lr,
                     float beta1,
                     float beta2,
                     float max_coeff,
                     float min_coeff,
                     float eps,
                     float grad_scale,
                     int step,
                     int mode,
                     int bias_correction,
                     float decay,
                     at::Tensor& w_l2_i,
                     at::Tensor& u_l2_i,
                     at::Tensor& lamb_coeff_val);

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x)

// C++ interface
at::Tensor lamb(at::Tensor& p,
                at::Tensor& p_copy,
                at::Tensor& m,
                at::Tensor& v,
                at::Tensor& g,
                float lr,
                float beta1,
                float beta2,
                float max_coeff,
                float min_coeff,
                float eps,
                float grad_scale,
                int step,
                int mode,
                int bias_correction,
                float decay)
{
    CHECK_INPUT(p);
    if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
    CHECK_INPUT(m);
    CHECK_INPUT(v);
    CHECK_INPUT(g);
    int64_t num_elem = p.numel();
    AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
    AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
    AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
    AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0,
               "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");

    // intermediate for weight L2 reduction
    // make sure that the threads per block is at least 512 during the kernel launch otherwise the
    // behavious is unexpected
    at::Tensor w_l2_i = at::empty(
        {512},
        p.options().dtype(p.type().scalarType() == at::ScalarType::Half ? at::ScalarType::Float
                                                                        : p.type().scalarType()));

    // intermediate for update L2 reduction
    // make sure that the threads per block is at least 512 during the kernel launch otherwise the
    // behavious is unexpected
    at::Tensor u_l2_i = at::empty(
        {512},
        p.options().dtype(p.type().scalarType() == at::ScalarType::Half ? at::ScalarType::Float
                                                                        : p.type().scalarType()));

    at::Tensor lamb_coeff_val = at::empty(
        {1},
        p.options().dtype(p.type().scalarType() == at::ScalarType::Half ? at::ScalarType::Float
                                                                        : p.type().scalarType()));

    fused_lamb_cuda(p,
                    p_copy,
                    m,
                    v,
                    g,
                    lr,
                    beta1,
                    beta2,
                    max_coeff,
                    min_coeff,
                    eps,
                    grad_scale,
                    step,
                    mode,
                    bias_correction,
                    decay,
                    w_l2_i,
                    u_l2_i,
                    lamb_coeff_val);

    return lamb_coeff_val;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("lamb", &lamb, "Adam optimized CUDA implementation with LAMB.");
}
deepspeed/ops/csrc/lamb/fused_lamb_cuda_kernel.cu (new file, mode 100644)

/* Copyright 2019 The Microsoft DeepSpeed Team */
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/ATen.h"
#include "ATen/TensorUtils.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
//#include "ATen/Type.h"
#include <THC/THCGeneral.h>
#include "ATen/AccumulateType.h"

#include <iostream>

//#include <helper_functions.h>
#include <cooperative_groups.h>
#include <cuda_runtime_api.h>
#include <stdio.h>

namespace cg = cooperative_groups;

// Utility class used to avoid linker errors with extern
// unsized shared memory arrays with templated type
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
template <typename T>
struct SharedMemory {
    // Ensure that we won't compile any un-specialized types
    __device__ inline operator T*()
    {
        extern __device__ void error(void);
        error();
        return NULL;
    }
};

template <>
struct SharedMemory<float> {
    __device__ inline operator float*()
    {
        extern __shared__ float s_float[];
        return s_float;
    }
};

template <>
struct SharedMemory<double> {
    __device__ inline operator double*()
    {
        extern __shared__ double s_double[];
        return s_double;
    }
};
}  // namespace

#include "type_shim.h"

typedef enum {
    ADAM_MODE_0 = 0,  // eps under square root
    ADAM_MODE_1 = 1   // eps outside square root
} adamMode_t;

// s_a and s_b are in shared memory
// g_a and g_b are in shared memory
template <typename T, int blockSize>
__device__ void reduce_block_in_shared_memory(T* s_a, T* s_b, T* g_a, T* g_b)
{
    // Handle to thread block group
    cg::thread_block cta = cg::this_thread_block();

    // perform block reduction in shared memory,
    unsigned int tid = cta.thread_rank();

    T a_sum = s_a[tid];
    T b_sum = s_b[tid];

    cg::sync(cta);

    // do reduction in shared mem
    if ((blockSize >= 512) && (tid < 256)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 256];
        s_b[tid] = b_sum = b_sum + s_b[tid + 256];
    }

    cg::sync(cta);

    if ((blockSize >= 256) && (tid < 128)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 128];
        s_b[tid] = b_sum = b_sum + s_b[tid + 128];
    }

    cg::sync(cta);

    if ((blockSize >= 128) && (tid < 64)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 64];
        s_b[tid] = b_sum = b_sum + s_b[tid + 64];
    }

    cg::sync(cta);

#if (__CUDA_ARCH__ >= 300)
    if (tid < 32) {
        cg::coalesced_group active = cg::coalesced_threads();

        // Fetch final intermediate sum from 2nd warp
        if (blockSize >= 64) {
            a_sum = a_sum + s_a[tid + 32];
            b_sum = b_sum + s_b[tid + 32];
        }

        // Reduce final warp using shuffle
        for (int offset = warpSize / 2; offset > 0; offset /= 2) {
            a_sum += active.shfl_down(a_sum, offset);
            b_sum += active.shfl_down(b_sum, offset);
        }
    }
#else
    if ((blockSize >= 64) && (tid < 32)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 32];
        s_b[tid] = b_sum = b_sum + s_b[tid + 32];
    }

    cg::sync(cta);

    if ((blockSize >= 32) && (tid < 16)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 16];
        s_b[tid] = b_sum = b_sum + s_b[tid + 16];
    }

    cg::sync(cta);

    if ((blockSize >= 16) && (tid < 8)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 8];
        s_b[tid] = b_sum = b_sum + s_b[tid + 8];
    }

    cg::sync(cta);

    if ((blockSize >= 8) && (tid < 4)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 4];
        s_b[tid] = b_sum = b_sum + s_b[tid + 4];
    }

    cg::sync(cta);

    if ((blockSize >= 4) && (tid < 2)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 2];
        s_b[tid] = b_sum = b_sum + s_b[tid + 2];
    }

    cg::sync(cta);

    if ((blockSize >= 2) && (tid < 1)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 1];
        s_b[tid] = b_sum = b_sum + s_b[tid + 1];
    }

    cg::sync(cta);
#endif

    // write result for this block to global mem
    if (tid == 0) {
        g_a[blockIdx.x] = (T)a_sum;
        g_b[blockIdx.x] = (T)b_sum;
    }
}

template <typename T, int blockSize>
__device__ void reduce_two_vectors_in_register(T a, T b, T* g_a, T* g_b)
{
    const int threadIdInBlock = cg::this_thread_block().thread_rank();

    T* s_a = SharedMemory<T>();
    T* s_b = SharedMemory<T>() + cg::this_thread_block().size();

    s_a[threadIdInBlock] = a;
    s_b[threadIdInBlock] = b;

    reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}

template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part1(
    T* __restrict__ p,
    GRAD_T* __restrict__ p_copy,  // For mixed precision training, pass NULL if not needed
    T* __restrict__ m,
    T* __restrict__ v,
    const GRAD_T* __restrict__ g,
    const float b1,
    const float b2,
    const float eps,
    const float grad_scale,
    const float step_size,
    const size_t tsize,
    adamMode_t mode,
    const float decay,
    T* __restrict__ w_l2_i,
    T* __restrict__ u_l2_i)
{
    // Assuming 2D grids and 2D blocks
    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
    const int threadsPerBlock = blockDim.x * blockDim.y;
    const int threadIdInBlock = cg::this_thread_block().thread_rank();
    const int i = (blockId * threadsPerBlock + threadIdInBlock);
    const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;

    T reg_w = 0;
    T reg_u = 0;

    for (int j = i; j < tsize; j += totThreads) {
        T scaled_grad = g[j] / grad_scale;
        T pj = p[j];
        m[j] = b1 * m[j] + (1 - b1) * scaled_grad;
        v[j] = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad;
        float denom;
        if (mode == ADAM_MODE_0)
            denom = sqrtf(v[j] + eps);
        else  // Mode 1
            denom = sqrtf(v[j]) + eps;
        T update = (m[j] / denom) + (decay * p[j]);

        reg_u += update * update;
        reg_w += pj * pj;
    }

    reduce_two_vectors_in_register<T, blockSize>(reg_w, reg_u, w_l2_i, u_l2_i);
}

template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part2(const size_t tsize, T* __restrict__ g_a, T* __restrict__ g_b)
{
    T* s_a = SharedMemory<T>();
    T* s_b = SharedMemory<T>() + cg::this_thread_block().size();

    const int threadIdInBlock = cg::this_thread_block().thread_rank();

    s_a[threadIdInBlock] = g_a[threadIdInBlock];
    s_b[threadIdInBlock] = g_b[threadIdInBlock];

    if (threadIdInBlock >= tsize) {
        s_a[threadIdInBlock] = 0.0;
        s_b[threadIdInBlock] = 0.0;
    }

    reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}

template <typename T, typename GRAD_T>
__global__ void lamb_cuda_kernel_part3(
    T* __restrict__ p,
    GRAD_T* __restrict__ p_copy,  // For mixed precision training, pass NULL if not needed
    T* __restrict__ m,
    T* __restrict__ v,
    const GRAD_T* __restrict__ g,
    const float b1,
    const float b2,
    const float max_coeff,
    const float min_coeff,
    const float eps,
    const float grad_scale,
    const float step_size,
    const size_t tsize,
    adamMode_t mode,
    const float decay,
    T* __restrict__ w_l2_i,
    T* __restrict__ u_l2_i,
    T* __restrict__ lamb_coeff_val)
{
    // Assuming 2D grids and 2D blocks
    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
    const int threadsPerBlock = blockDim.x * blockDim.y;
    const int threadIdInBlock = cg::this_thread_block().thread_rank();
    const int i = (blockId * threadsPerBlock + threadIdInBlock);
    const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;

    T reg_w = sqrtf(w_l2_i[0]);
    T reg_u = sqrtf(u_l2_i[0]);

    float lamb_coeff = 1.0;

    if (reg_w != 0 and reg_u != 0) {
        lamb_coeff = reg_w / reg_u;
        if (lamb_coeff > max_coeff) { lamb_coeff = max_coeff; }
        if (lamb_coeff < min_coeff) { lamb_coeff = min_coeff; }
    }

    if (blockId == 0 and threadIdInBlock == 0) {
        lamb_coeff_val[0] = lamb_coeff;
        // printf("Cuda Lamb Coeff is %.6f \n",lamb_coeff);
    }

    for (int j = i; j < tsize; j += totThreads) {
        T pj = (float)p[j];
        T mj = m[j];
        T vj = v[j];
        float denom;
        if (mode == ADAM_MODE_0)
            denom = sqrtf(vj + eps);
        else  // Mode 1
            denom = sqrtf(vj) + eps;
        T update = (mj / denom) + (decay * pj);

        pj = pj - (step_size * lamb_coeff * update);
        p[j] = pj;
        if (p_copy != NULL) p_copy[j] = (GRAD_T)pj;
    }
}

void fused_lamb_cuda(at::Tensor& p,
                     at::Tensor& p_copy,
                     at::Tensor& m,
                     at::Tensor& v,
                     at::Tensor& g,
                     float lr,
                     float beta1,
                     float beta2,
                     float max_coeff,
                     float min_coeff,
                     float eps,
                     float grad_scale,
                     int step,
                     int mode,
                     int bias_correction,
                     float decay,
                     at::Tensor& w_l2_i,
                     at::Tensor& u_l2_i,
                     at::Tensor& lamb_coeff)
{
    // using namespace at;

    // Get tensor size
    int tsize = p.numel();
    // Determine #threads and #blocks
    const int threadsPerBlock = 512;
    int num_blocks = (tsize + threadsPerBlock - 1) / threadsPerBlock;
    if (num_blocks > 512) num_blocks = 512;

    int smemsize = 0;
    if (p.type().scalarType() == at::ScalarType::Double)
        smemsize = 2 * threadsPerBlock * sizeof(double);
    else
        smemsize = 2 * threadsPerBlock * sizeof(float);

    const dim3 blocks(num_blocks);
    const dim3 threads(threadsPerBlock);

    AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p),
               "parameter tensor is too large to be indexed with int32");
    // Constants
    float step_size = 0;
    if (bias_correction == 1) {
        const float bias_correction1 = 1 - std::pow(beta1, step);
        const float bias_correction2 = 1 - std::pow(beta2, step);
        step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
    } else {
        step_size = lr;
    }
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if (g.type().scalarType() == at::ScalarType::Half) {
        // all other values should be fp32 for half gradients
        AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float,
                   "expected parameter to be of float type");
        // dispatch is done on the gradient type
        using namespace at;  // prevents "toString is undefined" errors
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(
            g.scalar_type(), "lamb_cuda_kernel", ([&] {
                using accscalar_t = at::acc_type<scalar_t, true>;
                lamb_cuda_kernel_part1<accscalar_t, scalar_t, threadsPerBlock>
                    <<<blocks, threadsPerBlock, smemsize, stream>>>(
                        p.data<accscalar_t>(),
                        p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
                        m.data<accscalar_t>(),
                        v.data<accscalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t)mode,
                        decay,
                        w_l2_i.data<accscalar_t>(),
                        u_l2_i.data<accscalar_t>());
                lamb_cuda_kernel_part2<accscalar_t, scalar_t, threadsPerBlock>
                    <<<1, threadsPerBlock, smemsize, stream>>>(
                        num_blocks, w_l2_i.data<accscalar_t>(), u_l2_i.data<accscalar_t>());
                lamb_cuda_kernel_part3<accscalar_t, scalar_t>
                    <<<blocks, threadsPerBlock, smemsize, stream>>>(
                        p.data<accscalar_t>(),
                        p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
                        m.data<accscalar_t>(),
                        v.data<accscalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        max_coeff,
                        min_coeff,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t)mode,
                        decay,
                        w_l2_i.data<accscalar_t>(),
                        u_l2_i.data<accscalar_t>(),
                        lamb_coeff.data<accscalar_t>());
            }));
    } else {
        using namespace at;
        AT_DISPATCH_FLOATING_TYPES(
            g.scalar_type(), "lamb_cuda_kernel", ([&] {
                lamb_cuda_kernel_part1<scalar_t, scalar_t, threadsPerBlock>
                    <<<blocks, threadsPerBlock, smemsize, stream>>>(
                        p.data<scalar_t>(),
                        NULL,  // don't output p_copy for fp32, it's wasted write
                        m.data<scalar_t>(),
                        v.data<scalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t)mode,
                        decay,
                        w_l2_i.data<scalar_t>(),
                        u_l2_i.data<scalar_t>());
                lamb_cuda_kernel_part2<scalar_t, scalar_t, threadsPerBlock>
                    <<<1, threadsPerBlock, smemsize, stream>>>(
                        num_blocks, w_l2_i.data<scalar_t>(), u_l2_i.data<scalar_t>());
                lamb_cuda_kernel_part3<scalar_t, scalar_t>
                    <<<blocks, threadsPerBlock, smemsize, stream>>>(
                        p.data<scalar_t>(),
                        NULL,  // don't output p_copy for fp32, it's wasted write
                        m.data<scalar_t>(),
                        v.data<scalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        max_coeff,
                        min_coeff,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t)mode,
                        decay,
                        w_l2_i.data<scalar_t>(),
                        u_l2_i.data<scalar_t>(),
                        lamb_coeff.data<scalar_t>());
            }));
    }
    THCudaCheck(cudaGetLastError());
}
// template __device__ void reduce_two_vectors_in_register<float,512>(float a, float b, float* g_a,
// float* g_b, cg::grid_group &cgg);
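Read together, the three kernel stages above amount to the following LAMB step (a summary of the code, written in the kernel's own argument names; step_size already folds in the optional Adam bias correction computed on the host):

\begin{aligned}
\hat g_j &= g_j / \text{grad\_scale}, \qquad
m_j \leftarrow b_1 m_j + (1-b_1)\hat g_j, \qquad
v_j \leftarrow b_2 v_j + (1-b_2)\hat g_j^2 \\
u_j &= \frac{m_j}{\sqrt{v_j} + \text{eps}} + \text{decay}\cdot p_j
\qquad (\text{ADAM\_MODE\_0 uses } \sqrt{v_j + \text{eps}} \text{ as the denominator}) \\
\text{lamb\_coeff} &= \operatorname{clamp}\!\left(\frac{\lVert p\rVert_2}{\lVert u\rVert_2},\ \text{min\_coeff},\ \text{max\_coeff}\right),
\qquad \text{lamb\_coeff} = 1 \text{ if } \lVert p\rVert_2 = 0 \text{ or } \lVert u\rVert_2 = 0 \\
p_j &\leftarrow p_j - \text{step\_size}\cdot \text{lamb\_coeff}\cdot u_j
\end{aligned}

Here part1 accumulates the per-block sums of squares of p and u, part2 reduces those block partials to scalars, and part3 takes the square roots, forms the trust ratio, and applies the update.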
deepspeed/ops/csrc/lamb/hip/fused_lamb_hip.cpp (new file, mode 100644)

/* Copyright 2019 The Microsoft DeepSpeed Team */
#include <torch/extension.h>

// CUDA forward declaration
void fused_lamb_cuda(at::Tensor& p,
                     at::Tensor& p_copy,
                     at::Tensor& m,
                     at::Tensor& v,
                     at::Tensor& g,
                     float lr,
                     float beta1,
                     float beta2,
                     float max_coeff,
                     float min_coeff,
                     float eps,
                     float grad_scale,
                     int step,
                     int mode,
                     int bias_correction,
                     float decay,
                     at::Tensor& w_l2_i,
                     at::Tensor& u_l2_i,
                     at::Tensor& lamb_coeff_val);

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x)

// C++ interface
at::Tensor lamb(at::Tensor& p,
                at::Tensor& p_copy,
                at::Tensor& m,
                at::Tensor& v,
                at::Tensor& g,
                float lr,
                float beta1,
                float beta2,
                float max_coeff,
                float min_coeff,
                float eps,
                float grad_scale,
                int step,
                int mode,
                int bias_correction,
                float decay)
{
    CHECK_INPUT(p);
    if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
    CHECK_INPUT(m);
    CHECK_INPUT(v);
    CHECK_INPUT(g);
    int64_t num_elem = p.numel();
    AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
    AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
    AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
    AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0,
               "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");

    // intermediate for weight L2 reduction
    // make sure that the threads per block is at least 512 during the kernel launch otherwise the
    // behavious is unexpected
    at::Tensor w_l2_i = at::empty(
        {512},
        p.options().dtype(p.type().scalarType() == at::ScalarType::Half ? at::ScalarType::Float
                                                                        : p.type().scalarType()));

    // intermediate for update L2 reduction
    // make sure that the threads per block is at least 512 during the kernel launch otherwise the
    // behavious is unexpected
    at::Tensor u_l2_i = at::empty(
        {512},
        p.options().dtype(p.type().scalarType() == at::ScalarType::Half ? at::ScalarType::Float
                                                                        : p.type().scalarType()));

    at::Tensor lamb_coeff_val = at::empty(
        {1},
        p.options().dtype(p.type().scalarType() == at::ScalarType::Half ? at::ScalarType::Float
                                                                        : p.type().scalarType()));

    fused_lamb_cuda(p,
                    p_copy,
                    m,
                    v,
                    g,
                    lr,
                    beta1,
                    beta2,
                    max_coeff,
                    min_coeff,
                    eps,
                    grad_scale,
                    step,
                    mode,
                    bias_correction,
                    decay,
                    w_l2_i,
                    u_l2_i,
                    lamb_coeff_val);

    return lamb_coeff_val;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("lamb", &lamb, "Adam optimized CUDA implementation with LAMB.");
}
deepspeed/ops/csrc/lamb/hip/fused_lamb_hip_kernel.hip (new file, mode 100644)

/* Copyright 2019 The Microsoft DeepSpeed Team */
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/ATen.h"
#include "ATen/TensorUtils.h"
#include "ATen/hip/HIPContext.h"
#include "ATen/hip/detail/IndexUtils.cuh"
//#include "ATen/Type.h"
#include <THH/THHGeneral.h>
#include "ATen/AccumulateType.h"
#include <iostream>
//#include <helper_functions.h>
#include <cooperative_groups.h>
#include <hip/hip_runtime_api.h>
#include <stdio.h>
namespace cg = cooperative_groups;
// Utility class used to avoid linker errors with extern
// unsized shared memory arrays with templated type
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
template <typename T>
struct SharedMemory {
// Ensure that we won't compile any un-specialized types
__device__ inline operator T*()
{
extern __device__ void error(void);
error();
return NULL;
}
};
template <>
struct SharedMemory<float> {
__device__ inline operator float*()
{
HIP_DYNAMIC_SHARED( float, s_float)
return s_float;
}
};
template <>
struct SharedMemory<double> {
__device__ inline operator double*()
{
HIP_DYNAMIC_SHARED( double, s_double)
return s_double;
}
};
} // namespace
#include "type_shim.h"
typedef enum {
ADAM_MODE_0 = 0, // eps under square root
ADAM_MODE_1 = 1 // eps outside square root
} adamMode_t;
// s_a and s_b are in shared memory
// g_a and g_b are in shared memory
template <typename T, int blockSize>
__device__ void reduce_block_in_shared_memory(T* s_a, T* s_b, T* g_a, T* g_b)
{
// Handle to thread block group
cg::thread_block cta = cg::this_thread_block();
// perform block reduction in shared memory,
unsigned int tid = cta.thread_rank();
T a_sum = s_a[tid];
T b_sum = s_b[tid];
cg::sync(cta);
// do reduction in shared mem
if ((blockSize >= 512) && (tid < 256)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 256];
s_b[tid] = b_sum = b_sum + s_b[tid + 256];
}
cg::sync(cta);
if ((blockSize >= 256) && (tid < 128)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 128];
s_b[tid] = b_sum = b_sum + s_b[tid + 128];
}
cg::sync(cta);
if ((blockSize >= 128) && (tid < 64)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 64];
s_b[tid] = b_sum = b_sum + s_b[tid + 64];
}
cg::sync(cta);
#if (__CUDA_ARCH__ >= 300)
if (tid < 32) {
cg::coalesced_group active = cg::coalesced_threads();
// Fetch final intermediate sum from 2nd warp
if (blockSize >= 64) {
a_sum = a_sum + s_a[tid + 32];
b_sum = b_sum + s_b[tid + 32];
}
// Reduce final warp using shuffle
for (int offset = warpSize / 2; offset > 0; offset /= 2) {
a_sum += active.shfl_down(a_sum, offset);
b_sum += active.shfl_down(b_sum, offset);
}
}
#else
if ((blockSize >= 64) && (tid < 32)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 32];
s_b[tid] = b_sum = b_sum + s_b[tid + 32];
}
cg::sync(cta);
if ((blockSize >= 32) && (tid < 16)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 16];
s_b[tid] = b_sum = b_sum + s_b[tid + 16];
}
cg::sync(cta);
if ((blockSize >= 16) && (tid < 8)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 8];
s_b[tid] = b_sum = b_sum + s_b[tid + 8];
}
cg::sync(cta);
if ((blockSize >= 8) && (tid < 4)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 4];
s_b[tid] = b_sum = b_sum + s_b[tid + 4];
}
cg::sync(cta);
if ((blockSize >= 4) && (tid < 2)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 2];
s_b[tid] = b_sum = b_sum + s_b[tid + 2];
}
cg::sync(cta);
if ((blockSize >= 2) && (tid < 1)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 1];
s_b[tid] = b_sum = b_sum + s_b[tid + 1];
}
cg::sync(cta);
#endif
// write result for this block to global mem
if (tid == 0) {
g_a[blockIdx.x] = (T)a_sum;
g_b[blockIdx.x] = (T)b_sum;
}
}
template <typename T, int blockSize>
__device__ void reduce_two_vectors_in_register(T a, T b, T* g_a, T* g_b)
{
const int threadIdInBlock = cg::this_thread_block().thread_rank();
T* s_a = SharedMemory<T>();
T* s_b = SharedMemory<T>() + cg::this_thread_block().size();
s_a[threadIdInBlock] = a;
s_b[threadIdInBlock] = b;
reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part1(
T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T* __restrict__ g,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay,
T* __restrict__ w_l2_i,
T* __restrict__ u_l2_i)
{
// Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = cg::this_thread_block().thread_rank();
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;
T reg_w = 0;
T reg_u = 0;
for (int j = i; j < tsize; j += totThreads) {
T scaled_grad = g[j] / grad_scale;
T pj = p[j];
m[j] = b1 * m[j] + (1 - b1) * scaled_grad;
v[j] = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(v[j] + eps);
else // Mode 1
denom = sqrtf(v[j]) + eps;
T update = (m[j] / denom) + (decay * p[j]);
reg_u += update * update;
reg_w += pj * pj;
}
reduce_two_vectors_in_register<T, blockSize>(reg_w, reg_u, w_l2_i, u_l2_i);
}
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part2(const size_t tsize, T* __restrict__ g_a, T* __restrict__ g_b)
{
T* s_a = SharedMemory<T>();
T* s_b = SharedMemory<T>() + cg::this_thread_block().size();
const int threadIdInBlock = cg::this_thread_block().thread_rank();
s_a[threadIdInBlock] = g_a[threadIdInBlock];
s_b[threadIdInBlock] = g_b[threadIdInBlock];
if (threadIdInBlock >= tsize) {
s_a[threadIdInBlock] = 0.0;
s_b[threadIdInBlock] = 0.0;
}
reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}
template <typename T, typename GRAD_T>
__global__ void lamb_cuda_kernel_part3(
T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T* __restrict__ g,
const float b1,
const float b2,
const float max_coeff,
const float min_coeff,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay,
T* __restrict__ w_l2_i,
T* __restrict__ u_l2_i,
T* __restrict__ lamb_coeff_val)
{
// Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = cg::this_thread_block().thread_rank();
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;
T reg_w = sqrtf(w_l2_i[0]);
T reg_u = sqrtf(u_l2_i[0]);
float lamb_coeff = 1.0;
if (reg_w != 0 and reg_u != 0) {
lamb_coeff = reg_w / reg_u;
if (lamb_coeff > max_coeff) { lamb_coeff = max_coeff; }
if (lamb_coeff < min_coeff) { lamb_coeff = min_coeff; }
}
if (blockId == 0 and threadIdInBlock == 0) {
lamb_coeff_val[0] = lamb_coeff;
// printf("Cuda Lamb Coeff is %.6f \n",lamb_coeff);
}
for (int j = i; j < tsize; j += totThreads) {
T pj = (float)p[j];
T mj = m[j];
T vj = v[j];
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(vj + eps);
else // Mode 1
denom = sqrtf(vj) + eps;
T update = (mj / denom) + (decay * pj);
pj = pj - (step_size * lamb_coeff * update);
p[j] = pj;
if (p_copy != NULL) p_copy[j] = (GRAD_T)pj;
}
}
void fused_lamb_cuda(at::Tensor& p,
at::Tensor& p_copy,
at::Tensor& m,
at::Tensor& v,
at::Tensor& g,
float lr,
float beta1,
float beta2,
float max_coeff,
float min_coeff,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay,
at::Tensor& w_l2_i,
at::Tensor& u_l2_i,
at::Tensor& lamb_coeff)
{
// using namespace at;
// Get tensor size
int tsize = p.numel();
// Determine #threads and #blocks
const int threadsPerBlock = 512;
int num_blocks = (tsize + threadsPerBlock - 1) / threadsPerBlock;
if (num_blocks > 512) num_blocks = 512;
int smemsize = 0;
if (p.type().scalarType() == at::ScalarType::Double)
smemsize = 2 * threadsPerBlock * sizeof(double);
else
smemsize = 2 * threadsPerBlock * sizeof(float);
const dim3 blocks(num_blocks);
const dim3 threads(threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p),
"parameter tensor is too large to be indexed with int32");
// Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - ::pow(beta1, step);
const float bias_correction2 = 1 - ::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
} else {
step_size = lr;
}
hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
if (g.type().scalarType() == at::ScalarType::Half) {
// all other values should be fp32 for half gradients
AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float,
"expected parameter to be of float type");
// dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
g.scalar_type(), "lamb_cuda_kernel", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
hipLaunchKernelGGL(( lamb_cuda_kernel_part1<accscalar_t, scalar_t, threadsPerBlock>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<accscalar_t>(),
p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
m.data<accscalar_t>(),
v.data<accscalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<accscalar_t>(),
u_l2_i.data<accscalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part2<accscalar_t, scalar_t, threadsPerBlock>)
, dim3(1), dim3(threadsPerBlock), smemsize, stream,
num_blocks, w_l2_i.data<accscalar_t>(), u_l2_i.data<accscalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part3<accscalar_t, scalar_t>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<accscalar_t>(),
p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
m.data<accscalar_t>(),
v.data<accscalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
max_coeff,
min_coeff,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<accscalar_t>(),
u_l2_i.data<accscalar_t>(),
lamb_coeff.data<accscalar_t>());
}));
} else {
using namespace at;
AT_DISPATCH_FLOATING_TYPES(
g.scalar_type(), "lamb_cuda_kernel", ([&] {
hipLaunchKernelGGL(( lamb_cuda_kernel_part1<scalar_t, scalar_t, threadsPerBlock>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<scalar_t>(),
NULL, // don't output p_copy for fp32, it's wasted write
m.data<scalar_t>(),
v.data<scalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<scalar_t>(),
u_l2_i.data<scalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part2<scalar_t, scalar_t, threadsPerBlock>)
, dim3(1), dim3(threadsPerBlock), smemsize, stream,
num_blocks, w_l2_i.data<scalar_t>(), u_l2_i.data<scalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part3<scalar_t, scalar_t>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<scalar_t>(),
NULL, // don't output p_copy for fp32, it's wasted write
m.data<scalar_t>(),
v.data<scalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
max_coeff,
min_coeff,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<scalar_t>(),
u_l2_i.data<scalar_t>(),
lamb_coeff.data<scalar_t>());
}));
}
THCudaCheck(hipGetLastError());
}
// template __device__ void reduce_two_vectors_in_register<float,512>(float a, float b, float* g_a,
// float* g_b, cg::grid_group &cgg);
deepspeed/ops/csrc/sparse_attention/hip/utils.cpp (new file, mode 100644)

// DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
// https://github.com/ptillet/torch-blocksparse/blob/master/csrc/utils.cpp
#include <torch/extension.h>
#include <string>
#include <tuple>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif
typedef std::vector<std::tuple<int, torch::Tensor>> ret_t;

void segment_blocks(torch::Tensor layout,
                    torch::Tensor idx,
                    torch::Tensor scratch,
                    int max_width,
                    ret_t& ret)
{
    size_t H = layout.size(0);
    size_t M = layout.size(1);
    size_t N = layout.size(2);
    torch::Tensor tmp = torch::zeros_like(layout);

    auto _tmp = tmp.accessor<int, 3>();
    auto _layout = layout.accessor<int, 3>();
    auto _idx = idx.accessor<int, 3>();
    auto _scratch = scratch.accessor<int, 3>();
    std::vector<int> current(H, 0);

#ifdef _OPENMP
#pragma omp parallel for
#endif
    for (size_t h = 0; h < H; h++) {
        // surrounding indices
        std::vector<int> ii_left(max_width, -1);
        std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));

        for (size_t m = 0; m < M; m++) {
            for (size_t n = 0; n < N; n++) {
                int v = _layout[h][m][n];
                if (v == 0) continue;
                int n_left = ii_left[max_width - 1];
                int m_top = ii_top[max_width - 1][n];
                int top = (m_top >= 0) ? _tmp[h][m_top][n] : 0;
                int left = (n_left >= 0) ? _tmp[h][m][n_left] : 0;
                int topleft = (m_top >= 0 && n_left >= 0) ? _tmp[h][m_top][n_left] : 0;
                int width = std::min(left, std::min(top, topleft)) + 1;

                // reset width if blocks cannot be
                // packed together (i.e., there's a 1 "in the middle")
                for (int nn = n_left + 1; nn < n; nn++)
                    if (ii_top[max_width - 1][nn] > ii_top[max_width - 1][n]) width = 1;
                _tmp[h][m][n] = width;

                // update n_left ring buffer
                for (int k = 0; k < max_width - 1; k++) ii_left[k] = ii_left[k + 1];
                ii_left[max_width - 1] = n;

                // update ii_top ring buffer
                for (int k = 0; k < max_width - 1; k++) ii_top[k][n] = ii_top[k + 1][n];
                ii_top[max_width - 1][n] = m;

                // block is too small -- skip
                if (width != max_width) continue;

                // retained blocks are set to zeros
                for (size_t km = 0; km < max_width; km++)
                    for (size_t kn = 0; kn < max_width; kn++) {
                        int mm = ii_top[km][n];
                        int nn = ii_left[kn];
                        if (mm < 0 || nn < 0) continue;
                        _layout[h][mm][nn] = 0;
                        _tmp[h][mm][nn] = 0;
                        _scratch[h][current[h]][0] = (int)h;
                        _scratch[h][current[h]][1] = (int)mm;
                        _scratch[h][current[h]][2] = (int)nn;
                        _scratch[h][current[h]][3] = _idx[h][mm][nn];
                        current[h]++;
                    }
            }
        }
    }
    std::vector<torch::Tensor> to_cat;
    for (size_t h = 0; h < H; h++)
        if (current[h] > 0) to_cat.push_back(scratch[h].slice(0, 0, current[h]));
    if (!to_cat.empty()) ret.push_back({max_width, torch::cat(to_cat)});
}

ret_t sdd_segment(torch::Tensor layout, int start_width)
{
    ret_t ret;

    // block index
    torch::Tensor idx = torch::zeros_like(layout);
    int current = 0;
    size_t H = layout.size(0);
    size_t M = layout.size(1);
    size_t N = layout.size(2);
    auto _layout = layout.accessor<int, 3>();
    auto _idx = idx.accessor<int, 3>();
    for (size_t h = 0; h < H; h++)
        for (size_t m = 0; m < M; m++)
            for (size_t n = 0; n < N; n++) {
                if (_layout[h][m][n] == 0) continue;
                _idx[h][m][n] = current++;
            }

    // scratch memory
    torch::Tensor scratch = torch::empty({H, layout.sum().item<int>(), 4}, layout.dtype());
    for (int max_width = start_width; max_width > 0; max_width /= 2)
        segment_blocks(layout, idx, scratch, max_width, ret);
    return ret;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("sdd_segment", &sdd_segment, "SDD segmentation handler");
}
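For reference, a minimal sketch of how sdd_segment could be driven directly from C++ with libtorch (the layout values below are made up; in DeepSpeed this function is normally reached from Python through the sdd_segment binding above):

#include <torch/torch.h>

// Assumes sdd_segment and ret_t from utils.cpp are linked in.
ret_t segment_example()
{
    // one head, 4x4 block layout with a dense 2x2 patch in the top-left corner
    torch::Tensor layout = torch::zeros({1, 4, 4}, torch::kInt32);
    layout[0][0][0] = 1; layout[0][0][1] = 1;
    layout[0][1][0] = 1; layout[0][1][1] = 1;
    // starting from width 2, the 2x2 patch is emitted as one segment;
    // any remaining isolated blocks would fall through to width 1
    return sdd_segment(layout, 2);
}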
deepspeed/ops/csrc/sparse_attention/utils.cpp
0 → 100644
// DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
// https://github.com/ptillet/torch-blocksparse/blob/master/csrc/utils.cpp
#include <torch/extension.h>
#include <string>
#include <tuple>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif
typedef std::vector<std::tuple<int, torch::Tensor>> ret_t;

void segment_blocks(torch::Tensor layout,
                    torch::Tensor idx,
                    torch::Tensor scratch,
                    int max_width,
                    ret_t& ret)
{
    size_t H = layout.size(0);
    size_t M = layout.size(1);
    size_t N = layout.size(2);
    torch::Tensor tmp = torch::zeros_like(layout);

    auto _tmp = tmp.accessor<int, 3>();
    auto _layout = layout.accessor<int, 3>();
    auto _idx = idx.accessor<int, 3>();
    auto _scratch = scratch.accessor<int, 3>();
    std::vector<int> current(H, 0);

#ifdef _OPENMP
#pragma omp parallel for
#endif
    for (size_t h = 0; h < H; h++) {
        // surrounding indices
        std::vector<int> ii_left(max_width, -1);
        std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));

        for (size_t m = 0; m < M; m++) {
            for (size_t n = 0; n < N; n++) {
                int v = _layout[h][m][n];
                if (v == 0) continue;
                int n_left = ii_left[max_width - 1];
                int m_top = ii_top[max_width - 1][n];
                int top = (m_top >= 0) ? _tmp[h][m_top][n] : 0;
                int left = (n_left >= 0) ? _tmp[h][m][n_left] : 0;
                int topleft = (m_top >= 0 && n_left >= 0) ? _tmp[h][m_top][n_left] : 0;
                int width = std::min(left, std::min(top, topleft)) + 1;

                // reset width if blocks cannot be
                // packed together (i.e., there's a 1 "in the middle")
                for (int nn = n_left + 1; nn < n; nn++)
                    if (ii_top[max_width - 1][nn] > ii_top[max_width - 1][n]) width = 1;
                _tmp[h][m][n] = width;

                // update n_left ring buffer
                for (int k = 0; k < max_width - 1; k++) ii_left[k] = ii_left[k + 1];
                ii_left[max_width - 1] = n;

                // update ii_top ring buffer
                for (int k = 0; k < max_width - 1; k++) ii_top[k][n] = ii_top[k + 1][n];
                ii_top[max_width - 1][n] = m;

                // block is too small -- skip
                if (width != max_width) continue;

                // retained blocks are set to zeros
                for (size_t km = 0; km < max_width; km++)
                    for (size_t kn = 0; kn < max_width; kn++) {
                        int mm = ii_top[km][n];
                        int nn = ii_left[kn];
                        if (mm < 0 || nn < 0) continue;
                        _layout[h][mm][nn] = 0;
                        _tmp[h][mm][nn] = 0;
                        _scratch[h][current[h]][0] = (int)h;
                        _scratch[h][current[h]][1] = (int)mm;
                        _scratch[h][current[h]][2] = (int)nn;
                        _scratch[h][current[h]][3] = _idx[h][mm][nn];
                        current[h]++;
                    }
            }
        }
    }
    std::vector<torch::Tensor> to_cat;
    for (size_t h = 0; h < H; h++)
        if (current[h] > 0) to_cat.push_back(scratch[h].slice(0, 0, current[h]));
    if (!to_cat.empty()) ret.push_back({max_width, torch::cat(to_cat)});
}

ret_t sdd_segment(torch::Tensor layout, int start_width)
{
    ret_t ret;

    // block index
    torch::Tensor idx = torch::zeros_like(layout);
    int current = 0;
    size_t H = layout.size(0);
    size_t M = layout.size(1);
    size_t N = layout.size(2);
    auto _layout = layout.accessor<int, 3>();
    auto _idx = idx.accessor<int, 3>();
    for (size_t h = 0; h < H; h++)
        for (size_t m = 0; m < M; m++)
            for (size_t n = 0; n < N; n++) {
                if (_layout[h][m][n] == 0) continue;
                _idx[h][m][n] = current++;
            }

    // scratch memory
    torch::Tensor scratch = torch::empty({H, layout.sum().item<int>(), 4}, layout.dtype());
    for (int max_width = start_width; max_width > 0; max_width /= 2)
        segment_blocks(layout, idx, scratch, max_width, ret);
    return ret;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("sdd_segment", &sdd_segment, "SDD segmentation handler");
}
deepspeed/ops/csrc/transformer/cublas_wrappers.cu
0 → 100644
#include "cublas_wrappers.h"
int cublas_gemm_ex(cublasHandle_t handle,
                   cublasOperation_t transa,
                   cublasOperation_t transb,
                   int m, int n, int k,
                   const float* alpha, const float* beta,
                   const float* A, const float* B, float* C,
                   cublasGemmAlgo_t algo)
{
    cublasStatus_t status = cublasGemmEx(handle, transa, transb, m, n, k,
                                         (const void*)alpha,
                                         (const void*)A, CUDA_R_32F, (transa == CUBLAS_OP_N) ? m : k,
                                         (const void*)B, CUDA_R_32F, (transb == CUBLAS_OP_N) ? k : n,
                                         (const void*)beta,
                                         C, CUDA_R_32F, m,
                                         CUDA_R_32F, algo);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m, n, k, (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}

int cublas_gemm_ex(cublasHandle_t handle,
                   cublasOperation_t transa,
                   cublasOperation_t transb,
                   int m, int n, int k,
                   const float* alpha, const float* beta,
                   const __half* A, const __half* B, __half* C,
                   cublasGemmAlgo_t algo)
{
    cublasStatus_t status = cublasGemmEx(handle, transa, transb, m, n, k,
                                         (const void*)alpha,
                                         (const void*)A, CUDA_R_16F, (transa == CUBLAS_OP_N) ? m : k,
                                         (const void*)B, CUDA_R_16F, (transb == CUBLAS_OP_N) ? k : n,
                                         (const void*)beta,
                                         (void*)C, CUDA_R_16F, m,
                                         CUDA_R_32F, algo);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m, n, k, (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}

int cublas_strided_batched_gemm(cublasHandle_t handle,
                                int m, int n, int k,
                                const float* alpha, const float* beta,
                                const float* A, const float* B, float* C,
                                cublasOperation_t op_A, cublasOperation_t op_B,
                                int stride_A, int stride_B, int stride_C,
                                int batch,
                                cublasGemmAlgo_t algo)
{
    cublasStatus_t status = cublasGemmStridedBatchedEx(handle, op_A, op_B, m, n, k,
                                                       alpha,
                                                       A, CUDA_R_32F, (op_A == CUBLAS_OP_N) ? m : k, stride_A,
                                                       B, CUDA_R_32F, (op_B == CUBLAS_OP_N) ? k : n, stride_B,
                                                       beta,
                                                       C, CUDA_R_32F, m, stride_C,
                                                       batch,
                                                       CUDA_R_32F, algo);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr,
                "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
                batch, m, n, k, (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}

int cublas_strided_batched_gemm(cublasHandle_t handle,
                                int m, int n, int k,
                                const float* alpha, const float* beta,
                                const __half* A, const __half* B, __half* C,
                                cublasOperation_t op_A, cublasOperation_t op_B,
                                int stride_A, int stride_B, int stride_C,
                                int batch,
                                cublasGemmAlgo_t algo)
{
    cublasStatus_t status = cublasGemmStridedBatchedEx(handle, op_A, op_B, m, n, k,
                                                       alpha,
                                                       A, CUDA_R_16F, (op_A == CUBLAS_OP_N) ? m : k, stride_A,
                                                       B, CUDA_R_16F, (op_B == CUBLAS_OP_N) ? k : n, stride_B,
                                                       beta,
                                                       C, CUDA_R_16F, m, stride_C,
                                                       batch,
                                                       CUDA_R_32F, algo);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m, n, k, (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}
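These wrappers follow cuBLAS's column-major convention: the leading dimension of A is m when transa is CUBLAS_OP_N and k otherwise, B's is k or n accordingly, and C is always m x n with leading dimension m. A minimal sketch of a single-precision call (only a sketch; the handle and the device buffers d_A, d_B, d_C are assumed to be created and filled elsewhere):

// Sketch only: d_A (m x k), d_B (k x n), d_C (m x n) are column-major device buffers.
void gemm_example(cublasHandle_t handle, const float* d_A, const float* d_B, float* d_C,
                  int m, int n, int k)
{
    const float alpha = 1.0f;
    const float beta = 0.0f;
    // C = alpha * A * B + beta * C, using the default algorithm selection
    cublas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                   &alpha, &beta, d_A, d_B, d_C, CUBLAS_GEMM_DEFAULT);
}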
deepspeed/ops/csrc/transformer/dropout_kernels.cu
0 → 100755
#include "custom_cuda_layers.h"
const int unroll_factor = 4;

__global__ void dropout_kernel(const int N,
                               const float ratio,
                               float* out,
                               const float* Xdata,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        float4 rand = curand_uniform4(&state);
        uint8_t m[unroll_factor];

        m[0] = (uint8_t)(rand.x > ratio); m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio); m[3] = (uint8_t)(rand.w > ratio);

        int i = j * unroll_factor;

        mask[i] = (uint8_t)m[0]; mask[i + 1] = (uint8_t)m[1];
        mask[i + 2] = (uint8_t)m[2]; mask[i + 3] = (uint8_t)m[3];

        out[i] = Xdata[i] * scale * m[0];
        out[i + 1] = Xdata[i + 1] * scale * m[1];
        out[i + 2] = Xdata[i + 2] * scale * m[2];
        out[i + 3] = Xdata[i + 3] * scale * m[3];
    }
    int high_index = ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            out[i] = Xdata[i] * scale * m;
            mask[i] = m;
        }
    }
}

__global__ void dropout_kernel(const int N,
                               const float ratio,
                               __half* out,
                               const __half* Xdata,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

#ifdef __STOCHASTIC_MODE__
    const __half2 h_scale = __float2half2_rn(scale);
    const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
    float2* out_cast = reinterpret_cast<float2*>(out);
    uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);

    uint32_t m_32;
    uint8_t* m = reinterpret_cast<uint8_t*>(&m_32);

    float2 result_f;
    __half2* result_h = reinterpret_cast<__half2*>(&result_f);
    __half2 mask_h[2];
    float2 mask_f[2];

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        float2 x_f = x_cast[j];
        __half2* x_h = reinterpret_cast<__half2*>(&x_f);

        float4 rand = curand_uniform4(&state);

        m[0] = (uint8_t)(rand.x > ratio); m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio); m[3] = (uint8_t)(rand.w > ratio);

        float* mask_f_data = &mask_f[0].x;
#pragma unroll
        for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);

        mask_h[0] = __float22half2_rn(mask_f[0]);
        mask_h[1] = __float22half2_rn(mask_f[1]);

        result_h[0] = x_h[0] * h_scale * mask_h[0];
        result_h[1] = x_h[1] * h_scale * mask_h[1];

        out_cast[j] = result_f;
        mask_cast[j] = m_32;
    }
#else
    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        int i = j * unroll_factor;

        const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
        float2 vals_half_f[2];
        vals_half_f[0] = __half22float2(vals_half[0]);
        vals_half_f[1] = __half22float2(vals_half[1]);

        uint8_t m[unroll_factor];
        float4 rand = curand_uniform4(&state);
        m[0] = (uint8_t)(rand.x > ratio); m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio); m[3] = (uint8_t)(rand.w > ratio);

        out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
        out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
        out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
        out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);

        mask[i] = m[0]; mask[i + 1] = m[1];
        mask[i + 2] = m[2]; mask[i + 3] = m[3];
    }
#endif
    int high_index = ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            out[i] = __float2half((float)Xdata[i] * scale * m);
            mask[i] = m;
        }
    }
}

__global__ void dropout_kernel_bwd(const int N,
                                   const float ratio,
                                   const float* Xdata,
                                   float* out,
                                   uint8_t* mask,
                                   std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        int i = j * unroll_factor;
        out[i] = mask[i] ? Xdata[i] * scale : 0.0;
        out[i + 1] = mask[i + 1] ? Xdata[i + 1] * scale : 0.0;
        out[i + 2] = mask[i + 2] ? Xdata[i + 2] * scale : 0.0;
        out[i + 3] = mask[i + 3] ? Xdata[i + 3] * scale : 0.0;
    }
    int high_index = ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        for (int i = high_index; i < N; i++) { out[i] = mask[i] ? Xdata[i] * scale : 0.0; }
    }
}

__global__ void dropout_kernel_bwd(const int N,
                                   const float ratio,
                                   const __half* Xdata,
                                   __half* out,
                                   uint8_t* mask,
                                   std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);

#ifdef __STOCHASTIC_MODE__
    const __half2 h_scale = __float2half2_rn(scale);
    const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
    float2* out_cast = reinterpret_cast<float2*>(out);
    uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        float2 x_f = x_cast[j];
        __half2* x_h = reinterpret_cast<__half2*>(&x_f);

        uint32_t m_32 = mask_cast[j];
        uint8_t* m = (uint8_t*)&m_32;

        __half2 mask_h[2];
        float2 mask_f[2];

        float* mask_f_data = &mask_f[0].x;
#pragma unroll
        for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);

#pragma unroll
        for (int i = 0; i < 2; i++) mask_h[i] = __float22half2_rn(mask_f[i]);

        float2 result_f;
        __half2* result_h = reinterpret_cast<__half2*>(&result_f);

        result_h[0] = x_h[0] * h_scale * mask_h[0];
        result_h[1] = x_h[1] * h_scale * mask_h[1];

        out_cast[j] = result_f;
    }
#else
    const __half h_scale = __float2half(scale);
    const __half h_zero = __float2half(0.0);

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        int i = j * unroll_factor;

        const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
        uint8_t* m = mask + i;

        float2 vals_half_f[2];
        vals_half_f[0] = __half22float2(vals_half[0]);
        vals_half_f[1] = __half22float2(vals_half[1]);

        out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
        out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
        out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
        out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);
    }
#endif
    int high_index = ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        for (int i = high_index; i < N; i++) {
            out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
        }
    }
}

template <typename T>
void launch_dropout(T* out,
                    const T* vals,
                    uint8_t* mask,
                    int total_count,
                    int dim,
                    float ratio,
                    cudaStream_t stream,
                    bool bwd)
{
    assert(unroll_factor == 4);

    dim3 grid_dim = DS_GET_BLOCKS(total_count / unroll_factor);
    dim3 block_dim = DS_CUDA_NUM_THREADS;

    if (dim > 512) {
        block_dim.x >>= 1;
        grid_dim.x <<= 1;
    }
    uint64_t inc = total_count / grid_dim.x / block_dim.x;
    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
    if (bwd)
        dropout_kernel_bwd<<<grid_dim, block_dim, 0, stream>>>(
            total_count, ratio, vals, out, mask, seed);
    else
        dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
            total_count, ratio, out, vals, mask, seed);
}

template void launch_dropout(float* out, const float* vals, uint8_t* mask, int total_count, int dim, float ratio, cudaStream_t stream, bool);
template void launch_dropout(__half* out, const __half* vals, uint8_t* mask, int total_count, int dim, float ratio, cudaStream_t stream, bool);

__global__ void dropout_grad_kernel(const int N, const float scale, float* Xdata, uint8_t* mask)
{
    CUDA_1D_KERNEL_LOOP(i, N) { Xdata[i] *= scale * mask[i]; }
}

__global__ void dropout_grad_kernel(const int N, const float scale, __half* Xdata, uint8_t* mask)
{
    const __half2 h_scale = __float2half2_rn(scale);
    float2* x_cast = reinterpret_cast<float2*>(Xdata);
    uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        float2 x_data = x_cast[j];
        uint32_t m_32 = mask_cast[j];
        uint8_t* m = (uint8_t*)&m_32;

        float2 result_f;
        __half2* result_h = reinterpret_cast<__half2*>(&result_f);

#ifdef __STOCHASTIC_MODE__
        __half2* x_data_h = reinterpret_cast<__half2*>(&x_data);
        __half2 mask_h[2];
        float2 mask_f[2];

        float* mask_f_data = &mask_f[0].x;
#pragma unroll
        for (int i = 0; i < unroll_factor; i++) *(mask_f_data++) = (float)(m[i]);

        mask_h[0] = __float22half2_rn(mask_f[0]);
        mask_h[1] = __float22half2_rn(mask_f[1]);

        result_h[0] = x_data_h[0] * h_scale * mask_h[0];
        result_h[1] = x_data_h[1] * h_scale * mask_h[1];
#else
        __half* x_data_h = reinterpret_cast<__half*>(&x_data);
        float2 result[2];

        result[0].x = (float)x_data_h[0] * scale * m[0];
        result[0].y = (float)x_data_h[1] * scale * m[1];
        result[1].x = (float)x_data_h[2] * scale * m[2];
        result[1].y = (float)x_data_h[3] * scale * m[3];

        result_h[0] = __float22half2_rn(result[0]);
        result_h[1] = __float22half2_rn(result[1]);
#endif
        x_cast[j] = result_f;
    }
    int high_index = ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        for (int i = high_index; i < N; i++) {
            Xdata[i] = __float2half((float)Xdata[i] * scale * mask[i]);
        }
    }
}

template <typename T>
void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, cudaStream_t stream)
{
    assert(unroll_factor == 4);

    const float scale = 1. / (1. - ratio);
    dropout_grad_kernel<<<DS_GET_BLOCKS(total_count / unroll_factor), DS_CUDA_NUM_THREADS, 0, stream>>>(
        total_count, scale, vals, mask);
}

template void launch_dropout_grad(float* vals, uint8_t* mask, int total_count, float ratio, cudaStream_t stream);
template void launch_dropout_grad(__half* vals, uint8_t* mask, int total_count, float ratio, cudaStream_t stream);

__global__ void dropout_grad_kernel(const int N, const float scale, const float* Xdata, float* out, uint8_t* mask)
{
    CUDA_1D_KERNEL_LOOP(i, N) { out[i] = Xdata[i] * scale * mask[i]; }
}

__global__ void dropout_grad_kernel(const int N, const float scale, const __half* Xdata, __half* out, uint8_t* mask)
{
    const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
    float2* out_cast = reinterpret_cast<float2*>(out);
    const uint32_t* mask_cast = reinterpret_cast<const uint32_t*>(mask);

    float2 result_f;
    __half2* result_h = reinterpret_cast<__half2*>(&result_f);

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        float2 x_data = x_cast[j];
        uint32_t m_32 = mask_cast[j];
        uint8_t* m = (uint8_t*)&m_32;

        __half* x_data_h = reinterpret_cast<__half*>(&x_data);
        float2 result[2];

        result[0].x = (float)x_data_h[0] * scale * m[0];
        result[0].y = (float)x_data_h[1] * scale * m[1];
        result[1].x = (float)x_data_h[2] * scale * m[2];
        result[1].y = (float)x_data_h[3] * scale * m[3];

        result_h[0] = __float22half2_rn(result[0]);
        result_h[1] = __float22half2_rn(result[1]);

        out_cast[j] = result_f;
    }
    int high_index = ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        for (int i = high_index; i < N; i++) {
            out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
        }
    }
}

template <typename T>
void launch_dropout_grad(T* vals_out, const T* vals, uint8_t* mask, int total_count, float ratio, cudaStream_t stream)
{
    assert(unroll_factor == 4);

    const float scale = 1. / (1. - ratio);
    dropout_grad_kernel<<<DS_GET_BLOCKS(total_count / unroll_factor), DS_CUDA_NUM_THREADS, 0, stream>>>(
        total_count, scale, vals, vals_out, mask);
}

template void launch_dropout_grad(float*, const float* vals, uint8_t* mask, int total_count, float ratio, cudaStream_t stream);
template void launch_dropout_grad(__half*, const __half* vals, uint8_t* mask, int total_count, float ratio, cudaStream_t stream);

__global__ void dropout_kernel(const int N,
                               const int dim,
                               const float ratio,
                               const float* bias,
                               float* Xdata,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int tid = threadIdx.x % (dim / unroll_factor);

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    float4* Xdata_cast = reinterpret_cast<float4*>(Xdata);
    uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
    const float4* bias_cast = reinterpret_cast<const float4*>(bias);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        float4 rand = curand_uniform4(&state);

        uint32_t m_32;
        uint8_t* m = (uint8_t*)&m_32;

        m[0] = (uint8_t)(rand.x > ratio); m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio); m[3] = (uint8_t)(rand.w > ratio);

        float4 x_data = Xdata_cast[j];
        float4 b_data = bias_cast[j % (dim / unroll_factor)];

        x_data.x += b_data.x; x_data.y += b_data.y;
        x_data.z += b_data.z; x_data.w += b_data.w;

        x_data.x = x_data.x * scale * m[0];
        x_data.y = x_data.y * scale * m[1];
        x_data.z = x_data.z * scale * m[2];
        x_data.w = x_data.w * scale * m[3];

        mask_32[j] = m_32;
        Xdata_cast[j] = x_data;
    }
    int high_index = ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            float x_data = Xdata[i] + bias[i % dim];
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            Xdata[i] = x_data * scale * m;
            mask[i] = m;
        }
    }
}

__global__ void dropout_kernel(const int N,
                               const int dim,
                               const float ratio,
                               const __half* bias,
                               __half* Xdata,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int tid = threadIdx.x % (dim / unroll_factor);

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    float2* Xdata_cast = reinterpret_cast<float2*>(Xdata);
    uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
    const float2* bias_cast = reinterpret_cast<const float2*>(bias);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        float4 rand = curand_uniform4(&state);

        float2 data_f;
        __half2* data_h = reinterpret_cast<__half2*>(&data_f);

        float2 bias_f;
        __half2* bias_h = reinterpret_cast<__half2*>(&bias_f);

        data_f = Xdata_cast[j];
        bias_f = bias_cast[j % (dim / unroll_factor)];

        float2 data_h_0 = __half22float2(data_h[0]);
        float2 data_h_1 = __half22float2(data_h[1]);
        float2 bias_h_0 = __half22float2(bias_h[0]);
        float2 bias_h_1 = __half22float2(bias_h[1]);

        data_h_0.x += bias_h_0.x; data_h_0.y += bias_h_0.y;
        data_h_1.x += bias_h_1.x; data_h_1.y += bias_h_1.y;

        uint32_t m_32;
        uint8_t* m = (uint8_t*)&m_32;

        m[0] = (uint8_t)(rand.x > ratio); m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio); m[3] = (uint8_t)(rand.w > ratio);

        data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
        data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
        data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
        data_h_1.y = __float2half(data_h_1.y * scale * m[3]);

        float2 result_f;
        __half2* result_h = reinterpret_cast<__half2*>(&result_f);

        result_h[0] = __float22half2_rn(data_h_0);
        result_h[1] = __float22half2_rn(data_h_1);

        Xdata_cast[j] = result_f;
        mask_32[j] = m_32;
    }
    int high_index = ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            float x_data = (float)Xdata[i] + (float)bias[i % dim];
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            Xdata[i] = __float2half(x_data * scale * m);
            mask[i] = m;
        }
    }
}

template <typename T>
void launch_dropout(T* out, const T* bias, uint8_t* mask, int batch, int dim, float ratio, cudaStream_t stream)
{
    assert(unroll_factor == 4);

    int total_count = batch * dim / unroll_factor;

    dim3 grid_dim = DS_GET_BLOCKS(total_count);
    dim3 block_dim = DS_CUDA_NUM_THREADS;

    uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);

    dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(total_count, dim, ratio, bias, out, mask, seed);
}

template void launch_dropout(float*, const float* bias, uint8_t* mask, int batch, int dim, float ratio, cudaStream_t stream);
template void launch_dropout(__half*, const __half* bias, uint8_t* mask, int batch, int dim, float ratio, cudaStream_t stream);

__global__ void dropout_kernel(const int N,
                               const int dim,
                               const float ratio,
                               const float* input,
                               const float* residual,
                               const float* bias,
                               float* out,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int tid = threadIdx.x % (dim / unroll_factor);

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    float4* out_cast = reinterpret_cast<float4*>(out);
    uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
    const float4* bias_cast = reinterpret_cast<const float4*>(bias);
    const float4* residual_cast = reinterpret_cast<const float4*>(residual);
    const float4* input_cast = reinterpret_cast<const float4*>(input);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        float4 rand = curand_uniform4(&state);

        uint32_t m_32;
        uint8_t* m = (uint8_t*)&m_32;

        m[0] = (uint8_t)(rand.x > ratio); m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio); m[3] = (uint8_t)(rand.w > ratio);

        float4 out_data;
        float4 b_data = bias_cast[j % (dim / unroll_factor)];
        float4 res_data = residual_cast[j];
        float4 inp_data = input_cast[j];

        out_data.x = (b_data.x + inp_data.x);
        out_data.y = (b_data.y + inp_data.y);
        out_data.z = (b_data.z + inp_data.z);
        out_data.w = (b_data.w + inp_data.w);

        out_data.x = out_data.x * scale * m[0];
        out_data.y = out_data.y * scale * m[1];
        out_data.z = out_data.z * scale * m[2];
        out_data.w = out_data.w * scale * m[3];

        out_data.x += res_data.x; out_data.y += res_data.y;
        out_data.z += res_data.z; out_data.w += res_data.w;

        mask_32[j] = m_32;
        out_cast[j] = out_data;
    }
    int high_index = ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            float x_data = input[i] + bias[i % dim];
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            x_data = x_data * scale * m;
            x_data += residual[i];

            out[i] = x_data;
            mask[i] = m;
        }
    }
}

__global__ void dropout_kernel(const int N,
                               const int dim,
                               const float ratio,
                               const __half* input,
                               const __half* residual,
                               const __half* bias,
                               __half* out,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int tid = threadIdx.x % (dim / unroll_factor);

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    float2* out_cast = reinterpret_cast<float2*>(out);
    uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
    const float2* bias_cast = reinterpret_cast<const float2*>(bias);
    const float2* residual_cast = reinterpret_cast<const float2*>(residual);
    const float2* input_cast = reinterpret_cast<const float2*>(input);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        float4 rand = curand_uniform4(&state);

        float2 data_f;
        __half2* data_h = reinterpret_cast<__half2*>(&data_f);

        float2 bias_f;
        __half2* bias_h = reinterpret_cast<__half2*>(&bias_f);

        float2 residual_f;
        __half2* residual_h = reinterpret_cast<__half2*>(&residual_f);

        float2 input_f;
        __half2* input_h = reinterpret_cast<__half2*>(&input_f);

        bias_f = bias_cast[j % (dim / unroll_factor)];
        residual_f = residual_cast[j];
        input_f = input_cast[j];

        float2 data_h_0 = __half22float2(data_h[0]);
        float2 data_h_1 = __half22float2(data_h[1]);
        float2 bias_h_0 = __half22float2(bias_h[0]);
        float2 bias_h_1 = __half22float2(bias_h[1]);
        float2 residual_h_0 = __half22float2(residual_h[0]);
        float2 residual_h_1 = __half22float2(residual_h[1]);
        float2 input_h_0 = __half22float2(input_h[0]);
        float2 input_h_1 = __half22float2(input_h[1]);

        data_h_0.x = (bias_h_0.x + input_h_0.x);
        data_h_0.y = (bias_h_0.y + input_h_0.y);
        data_h_1.x = (bias_h_1.x + input_h_1.x);
        data_h_1.y = (bias_h_1.y + input_h_1.y);

        uint32_t m_32;
        uint8_t* m = (uint8_t*)&m_32;

        m[0] = (uint8_t)(rand.x > ratio); m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio); m[3] = (uint8_t)(rand.w > ratio);

        data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
        data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
        data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
        data_h_1.y = __float2half(data_h_1.y * scale * m[3]);

        data_h_0.x += residual_h_0.x; data_h_0.y += residual_h_0.y;
        data_h_1.x += residual_h_1.x; data_h_1.y += residual_h_1.y;

        float2 result_f;
        __half2* result_h = reinterpret_cast<__half2*>(&result_f);

        result_h[0] = __float22half2_rn(data_h_0);
        result_h[1] = __float22half2_rn(data_h_1);

        out_cast[j] = result_f;
        mask_32[j] = m_32;
    }
    int high_index = ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            float x_data = (float)input[i] + (float)bias[i % dim];
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            x_data = x_data * scale * m;
            x_data += (float)residual[i];

            out[i] = __float2half(x_data);
            mask[i] = m;
        }
    }
}

template <typename T>
void launch_dropout(T* out,
                    const T* input,
                    const T* residual,
                    const T* bias,
                    uint8_t* mask,
                    int batch,
                    int dim,
                    float ratio,
                    cudaStream_t stream)
{
    assert(unroll_factor == 4);

    int total_count = batch * dim / unroll_factor;
    dim3 grid_dim = DS_GET_BLOCKS(total_count);
    dim3 block_dim = DS_CUDA_NUM_THREADS;

    uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);

    dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
        total_count, dim, ratio, input, residual, bias, out, mask, seed);
}

template void launch_dropout(float*, const float*, const float* residual, const float* bias, uint8_t* mask, int batch, int dim, float ratio, cudaStream_t stream);
template void launch_dropout(__half*, const __half*, const __half* residual, const __half* bias, uint8_t* mask, int batch, int dim, float ratio, cudaStream_t stream);
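All of these kernels implement inverted dropout: at training time surviving elements are scaled by 1/(1 - ratio) so no rescaling is needed at inference, and the per-element keep/drop decision is stored in mask for the backward pass. A scalar CPU sketch of the same forward/backward contract (helper names and the pre-generated random vector are ours, for illustration only):

#include <cstdint>
#include <vector>

// Forward: keep each element with probability (1 - ratio), scale survivors by 1/(1 - ratio).
void dropout_fwd_ref(const std::vector<float>& x, const std::vector<float>& rnd,
                     float ratio, std::vector<float>& out, std::vector<uint8_t>& mask)
{
    const float scale = 1.f / (1.f - ratio);
    for (size_t i = 0; i < x.size(); i++) {
        mask[i] = rnd[i] > ratio;          // same predicate as the kernels above
        out[i] = x[i] * scale * mask[i];
    }
}

// Backward: replay the stored mask (and the same scale) on the incoming gradient.
void dropout_bwd_ref(const std::vector<float>& dy, const std::vector<uint8_t>& mask,
                     float ratio, std::vector<float>& dx)
{
    const float scale = 1.f / (1.f - ratio);
    for (size_t i = 0; i < dy.size(); i++) dx[i] = dy[i] * scale * mask[i];
}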
deepspeed/ops/csrc/transformer/ds_transformer_cuda.cpp
0 → 100644
#include <torch/extension.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "Timer.h"
#include "context.h"
#include "cublas_wrappers.h"
#include "custom_cuda_layers.h"
#include "ds_transformer_cuda.h"
static std::unordered_map<int, std::shared_ptr<void>> s_transformer_layers;

const int init_seq_length = 128;

// C++ interface

template <typename T>
size_t get_workspace_size(int maxBatchSize,
                          int seq_len,
                          int hidden_size,
                          int intermediate_size,
                          int heads,
                          bool training,
                          bool gelu_checkpoint)
{
    size_t workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size);
    if (training) {
        workSpacesize += ((std::max)((size_t(maxBatchSize) * seq_len * intermediate_size),
                                     2 * (size_t(maxBatchSize) * heads * seq_len * seq_len)));
        if (gelu_checkpoint)
            workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * intermediate_size);
    }
    return workSpacesize;  // * sizeof(T);
}
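As a concrete illustration of the sizing above (our numbers, not from the source): for batch 8, sequence length 128, hidden size 1024, intermediate size 4096 and 16 heads in training mode without gelu checkpointing, the activation part is 4*8*128*1024 = 4,194,304 elements and the training scratch is max(8*128*4096, 2*8*16*128*128) = 4,194,304 elements, so the workspace holds 8,388,608 elements of T. A small sketch of the same arithmetic:

#include <algorithm>
#include <cstddef>

// Sketch: reproduce get_workspace_size<T> by hand for one sample configuration.
size_t workspace_example()
{
    const size_t b = 8, s = 128, h = 1024, inter = 4096, heads = 16;
    size_t ws = 4 * (b * s * h);                             // activation buffers
    ws += std::max(b * s * inter, 2 * (b * heads * s * s));  // training scratch
    return ws;                                               // 8,388,608 elements of T
}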
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
template
<
typename
T
>
BertTransformerLayer
<
T
>::
BertTransformerLayer
(
int
layer_id
,
int
batch_size
,
int
hidden_size
,
int
num_heads
,
int
intermediate_size
,
int
seq_length
,
float
attn_prob_dropout_ratio
,
float
hidden_output_dropout_ratio
,
float
layer_norm_eps
,
bool
pre_or_postLayerNorm
,
const
std
::
vector
<
std
::
array
<
int
,
3
>>&
gemm_algos
,
bool
attn_dropout_checkpoint
,
bool
normalize_invertible
,
bool
gelu_checkpoint
,
bool
stochastic_mode
)
:
_layer_id
(
layer_id
),
_batch_size
(
batch_size
),
_hidden_size
(
hidden_size
),
_heads
(
num_heads
),
_intermediate_size
(
intermediate_size
),
_seq_length
(
seq_length
),
_training
(
true
),
_pre_or_postLayerNorm
(
pre_or_postLayerNorm
),
_attn_dropout_checkpoint
(
attn_dropout_checkpoint
),
_normalize_invertible
(
normalize_invertible
),
_gelu_checkpoint
(
gelu_checkpoint
),
_stochastic_mode
(
stochastic_mode
),
_stream
(
Context
::
Instance
().
GetCurrentStream
()),
_cublasHandle
(
Context
::
Instance
().
GetCublasHandle
()),
_qkv_linear
(
typename
FeedForward
<
T
>::
Config
(
batch_size
*
seq_length
,
3
*
hidden_size
,
hidden_size
,
gemm_algos
[
0
])),
_attn_out_linear
(
typename
FeedForward
<
T
>::
Config
(
batch_size
*
seq_length
,
hidden_size
,
hidden_size
,
gemm_algos
[
0
])),
_attn_layer_norm
(
typename
Normalize_Layer
<
T
>::
Config
(
batch_size
,
seq_length
,
hidden_size
,
layer_norm_eps
,
true
,
!
normalize_invertible
)),
_layer_norm
(
typename
Normalize_Layer
<
T
>::
Config
(
batch_size
,
seq_length
,
hidden_size
,
layer_norm_eps
,
true
,
!
normalize_invertible
)),
_ff1
(
typename
FeedForward
<
T
>::
Config
(
batch_size
*
seq_length
,
_intermediate_size
,
hidden_size
,
gemm_algos
[
1
])),
_ff2
(
typename
FeedForward
<
T
>::
Config
(
batch_size
*
seq_length
,
hidden_size
,
_intermediate_size
,
gemm_algos
[
2
])),
_softmax
(
typename
Softmax
<
T
>::
Config
(
batch_size
,
num_heads
,
seq_length
)),
_gelu
(
typename
Gelu
<
T
>::
Config
(
_intermediate_size
)),
_attn_prob_dropout
(
typename
Dropout
<
T
>::
Config
(
attn_prob_dropout_ratio
,
_seq_length
)),
_attn_output_dropout
(
typename
Dropout
<
T
>::
Config
(
hidden_output_dropout_ratio
,
_hidden_size
)),
_layer_output_dropout
(
typename
Dropout
<
T
>::
Config
(
hidden_output_dropout_ratio
,
_hidden_size
)),
_attn_scores
(
typename
StridedBatchGemm
<
T
>::
Config
(
_batch_size
*
_heads
,
_seq_length
,
_seq_length
,
_hidden_size
/
_heads
,
(
T
(
1.0
)
/
T
(
sqrt
(
_hidden_size
/
_heads
))),
T
(
0.0
),
CUBLAS_OP_T
,
CUBLAS_OP_N
,
gemm_algos
[
3
])),
_attn_context
(
typename
StridedBatchGemm
<
T
>::
Config
(
_batch_size
*
_heads
,
_hidden_size
/
_heads
,
_seq_length
,
_seq_length
,
T
(
1.0
),
T
(
0.0
),
CUBLAS_OP_N
,
CUBLAS_OP_N
,
gemm_algos
[
4
]))
{
assert
(
_hidden_size
%
_heads
==
0
);
Initialize
();
}
template
<
typename
T
>
BertTransformerLayer
<
T
>::~
BertTransformerLayer
()
{
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
Initialize
()
{
if
(
std
::
is_same
<
T
,
__half
>::
value
)
cublasSetMathMode
(
_cublasHandle
,
CUBLAS_TENSOR_OP_MATH
);
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
Forward
(
int
bsz
,
const
T
*
input_ptr
,
const
T
*
input_mask_ptr
,
const
T
*
attn_qkvw_ptr
,
const
T
*
attn_qkvb_ptr
,
const
T
*
attn_ow_ptr
,
const
T
*
attn_ob_ptr
,
const
T
*
attn_nw_ptr
,
const
T
*
attn_nb_ptr
,
const
T
*
inter_w_ptr
,
const
T
*
inter_b_ptr
,
const
T
*
output_w_ptr
,
const
T
*
output_b_ptr
,
const
T
*
norm_w_ptr
,
const
T
*
norm_b_ptr
,
T
*
out_ptr
,
T
*
inp_norm_ptr
,
T
*
q_tf_ptr
,
T
*
k_tf_ptr
,
T
*
v_tf_ptr
,
T
*
soft_out_ptr
,
T
*
ctx_bufB_ptr
,
T
*
attn_o_inp_ptr
,
T
*
add_res_ptr
,
T
*
ff1_inp_ptr
,
T
*
gelu_inp_ptr
,
T
*
ff2_inp_ptr
)
{
cublasSetStream
(
_cublasHandle
,
_stream
);
if
(
!
_stochastic_mode
)
cudaStreamSynchronize
(
_stream
);
T
*
workspace
=
static_cast
<
T
*>
(
Context
::
Instance
().
GetWorkSpace
());
size_t
small_buf_size
=
bsz
*
_seq_length
*
_hidden_size
;
T
*
buf_0
=
workspace
;
T
*
buf_1
=
buf_0
+
small_buf_size
;
T
*
buf_2
=
buf_1
;
if
(
_normalize_invertible
)
{
add_res_ptr
=
buf_1
+
3
*
small_buf_size
;
buf_2
=
add_res_ptr
;
}
if
(
_gelu_checkpoint
)
buf_2
+=
small_buf_size
;
if
(
_attn_dropout_checkpoint
)
ctx_bufB_ptr
=
(
_gelu_checkpoint
?
(
buf_2
+
(
_intermediate_size
/
_hidden_size
)
*
small_buf_size
)
:
(
buf_1
+
4
*
small_buf_size
));
int
bsz_seq
=
bsz
*
_seq_length
;
if
(
_pre_or_postLayerNorm
)
{
if
(
_layer_norm
.
UseMean
())
_layer_norm
.
ForwardCheckpoint
(
bsz_seq
,
inp_norm_ptr
,
input_ptr
,
norm_w_ptr
,
norm_b_ptr
,
_stream
,
true
);
else
_layer_norm
.
Forward
(
bsz_seq
,
inp_norm_ptr
,
input_ptr
,
norm_w_ptr
,
norm_b_ptr
,
_stream
,
true
);
}
if
(
_pre_or_postLayerNorm
)
_qkv_linear
.
Forward
(
bsz_seq
,
inp_norm_ptr
,
attn_qkvw_ptr
,
buf_0
,
_cublasHandle
);
else
_qkv_linear
.
Forward
(
bsz_seq
,
input_ptr
,
attn_qkvw_ptr
,
buf_0
,
_cublasHandle
);
launch_bias_add_transform_0213
<
T
>
(
q_tf_ptr
,
buf_0
,
attn_qkvb_ptr
,
bsz
,
_seq_length
,
_hidden_size
,
_heads
,
_stream
,
3
);
int
bsz_heads
=
bsz
*
_heads
;
// attention scores
_attn_scores
.
Forward
(
bsz_heads
,
soft_out_ptr
,
k_tf_ptr
,
q_tf_ptr
,
_cublasHandle
);
// Softmax + Mask
_softmax
.
Forward
(
bsz
,
soft_out_ptr
,
input_mask_ptr
,
_stream
);
// attn prob dropout.
_attn_prob_dropout
.
Forward
(
bsz_heads
*
_seq_length
,
ctx_bufB_ptr
,
soft_out_ptr
,
_stream
);
// attention context
_attn_context
.
Forward
(
bsz_heads
,
buf_1
,
v_tf_ptr
,
ctx_bufB_ptr
,
_cublasHandle
);
launch_transform4d_0213
<
T
>
(
attn_o_inp_ptr
,
buf_1
,
bsz
,
_heads
,
_seq_length
,
_hidden_size
,
_stream
,
1
);
if
(
_pre_or_postLayerNorm
)
_attn_out_linear
.
Forward
(
bsz_seq
,
attn_o_inp_ptr
,
attn_ow_ptr
,
buf_1
,
_cublasHandle
);
else
_attn_out_linear
.
Forward
(
bsz_seq
,
attn_o_inp_ptr
,
attn_ow_ptr
,
ff1_inp_ptr
,
_cublasHandle
);
// attn output dropout.
if
(
_pre_or_postLayerNorm
)
_attn_output_dropout
.
ForwardWithBias
(
bsz_seq
,
add_res_ptr
,
buf_1
,
input_ptr
,
attn_ob_ptr
,
_stream
);
else
_attn_output_dropout
.
ForwardWithBias
(
bsz_seq
,
add_res_ptr
,
ff1_inp_ptr
,
input_ptr
,
attn_ob_ptr
,
_stream
);
if
(
_pre_or_postLayerNorm
)
{
if
(
_attn_layer_norm
.
UseMean
())
_attn_layer_norm
.
ForwardCheckpoint
(
bsz_seq
,
ff1_inp_ptr
,
add_res_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
_stream
,
true
);
else
_attn_layer_norm
.
Forward
(
bsz_seq
,
ff1_inp_ptr
,
add_res_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
_stream
,
true
);
}
else
{
if
(
_attn_layer_norm
.
UseMean
())
_attn_layer_norm
.
ForwardCheckpoint
(
bsz_seq
,
ff1_inp_ptr
,
add_res_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
_stream
,
true
);
else
_attn_layer_norm
.
Forward
(
bsz_seq
,
ff1_inp_ptr
,
add_res_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
_stream
,
true
);
}
_ff1
.
Forward
(
bsz_seq
,
ff1_inp_ptr
,
inter_w_ptr
,
(
_gelu_checkpoint
?
ff2_inp_ptr
:
gelu_inp_ptr
),
_cublasHandle
);
_gelu
.
ForwardWithBiasAdd
(
bsz_seq
,
(
_gelu_checkpoint
?
ff2_inp_ptr
:
gelu_inp_ptr
),
inter_b_ptr
,
(
_gelu_checkpoint
?
buf_2
:
ff2_inp_ptr
),
_stream
);
_ff2
.
Forward
(
bsz_seq
,
(
_gelu_checkpoint
?
buf_2
:
ff2_inp_ptr
),
output_w_ptr
,
out_ptr
,
_cublasHandle
);
// layer output dropout.
if
(
_pre_or_postLayerNorm
)
_layer_output_dropout
.
ForwardWithBias
(
bsz_seq
,
out_ptr
,
out_ptr
,
add_res_ptr
,
output_b_ptr
,
_stream
);
else
_layer_output_dropout
.
ForwardWithBias
(
bsz_seq
,
inp_norm_ptr
,
out_ptr
,
ff1_inp_ptr
,
output_b_ptr
,
_stream
);
if
(
!
_pre_or_postLayerNorm
)
{
if
(
_layer_norm
.
UseMean
())
_layer_norm
.
ForwardCheckpoint
(
bsz_seq
,
out_ptr
,
inp_norm_ptr
,
norm_w_ptr
,
norm_b_ptr
,
_stream
,
true
);
else
_layer_norm
.
Forward
(
bsz_seq
,
out_ptr
,
inp_norm_ptr
,
norm_w_ptr
,
norm_b_ptr
,
_stream
,
true
);
}
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
Backward
(
int
bsz
,
const
T
*
grad_output_ptr
,
const
T
*
input_ptr
,
const
T
*
output_ptr
,
const
T
*
inp_norm_ptr
,
const
T
*
q_tf_ptr
,
const
T
*
k_tf_ptr
,
const
T
*
v_tf_ptr
,
const
T
*
soft_out_ptr
,
const
T
*
ctx_bufB_ptr
,
const
T
*
attn_o_inp_ptr
,
const
T
*
add_res_ptr
,
const
T
*
ff1_inp_ptr
,
const
T
*
gelu_inp_ptr
,
const
T
*
ff2_inp_ptr
,
const
T
*
input_mask_ptr
,
const
T
*
attn_qkvw_ptr
,
const
T
*
attn_ow_ptr
,
const
T
*
attn_nw_ptr
,
const
T
*
attn_nb_ptr
,
const
T
*
inter_w_ptr
,
const
T
*
inter_b_ptr
,
const
T
*
output_w_ptr
,
const
T
*
norm_w_ptr
,
const
T
*
norm_b_ptr
,
T
*
grad_input_ptr
,
T
*
grad_attn_qkvw_ptr
,
T
*
grad_attn_qkvb_ptr
,
T
*
grad_attn_ow_ptr
,
T
*
grad_attn_ob_ptr
,
T
*
grad_attn_nw_ptr
,
T
*
grad_attn_nb_ptr
,
T
*
grad_inter_w_ptr
,
T
*
grad_inter_b_ptr
,
T
*
grad_output_w_ptr
,
T
*
grad_output_b_ptr
,
T
*
grad_norm_w_ptr
,
T
*
grad_norm_b_ptr
)
{
cublasSetStream
(
_cublasHandle
,
_stream
);
if
(
!
_stochastic_mode
)
cudaStreamSynchronize
(
_stream
);
T
*
workspace
=
static_cast
<
T
*>
(
Context
::
Instance
().
GetWorkSpace
());
size_t
small_buf_size
=
bsz
*
_seq_length
*
_hidden_size
;
T
*
buf_0
=
workspace
;
T
*
buf_1
=
buf_0
+
small_buf_size
;
T
*
buf_2
=
buf_1
+
small_buf_size
;
T
*
buf_3
=
buf_2
+
small_buf_size
;
T
*
ff2_buf
=
(
_gelu_checkpoint
?
buf_3
+
(
bsz
*
_seq_length
*
_intermediate_size
)
:
buf_3
+
small_buf_size
);
T
*
ctx_bufB_ptr_recomp
=
ff2_buf
+
(
_seq_length
*
_seq_length
*
bsz
*
_heads
);
cudaStream_t
streams
[
2
]
=
{
_stream
,
_stream
};
int
bsz_seq
=
bsz
*
_seq_length
;
int
bsz_heads
=
bsz
*
_heads
;
if
(
!
_pre_or_postLayerNorm
)
{
if
(
_layer_norm
.
UseMean
())
_layer_norm
.
Backward
(
bsz_seq
,
grad_output_ptr
,
norm_w_ptr
,
grad_norm_w_ptr
,
grad_norm_b_ptr
,
streams
,
buf_1
,
inp_norm_ptr
);
else
_layer_norm
.
Backward
(
bsz_seq
,
grad_output_ptr
,
norm_w_ptr
,
norm_b_ptr
,
grad_norm_w_ptr
,
grad_norm_b_ptr
,
streams
,
buf_1
,
output_ptr
);
}
if
(
_pre_or_postLayerNorm
)
_layer_output_dropout
.
Backward
(
bsz_seq
,
buf_0
,
grad_output_ptr
,
_stream
);
else
_layer_output_dropout
.
Backward
(
bsz_seq
,
buf_0
,
buf_1
,
_stream
);
const
T
*
layer_dropout_buf
=
_layer_output_dropout
.
HasDropout
()
?
buf_0
:
(
_pre_or_postLayerNorm
?
grad_output_ptr
:
buf_1
);
if
(
_gelu_checkpoint
)
_gelu
.
ForwardWithBiasAdd
(
bsz_seq
,
ff2_inp_ptr
,
inter_b_ptr
,
buf_2
,
_stream
);
_ff2
.
Backward
(
bsz_seq
,
layer_dropout_buf
,
(
_gelu_checkpoint
?
buf_2
:
ff2_inp_ptr
),
output_w_ptr
,
grad_output_w_ptr
,
grad_output_b_ptr
,
_cublasHandle
,
_stream
,
ff2_buf
);
_gelu
.
Backward
(
bsz_seq
,
ff2_buf
,
(
_gelu_checkpoint
?
ff2_inp_ptr
:
gelu_inp_ptr
),
inter_b_ptr
,
_stream
);
_ff1
.
Backward
(
bsz_seq
,
ff2_buf
,
ff1_inp_ptr
,
inter_w_ptr
,
grad_inter_w_ptr
,
grad_inter_b_ptr
,
_cublasHandle
,
_stream
,
buf_3
);
if
(
!
_pre_or_postLayerNorm
)
launch_fused_add2
<
T
>
(
buf_2
,
buf_3
,
buf_1
,
bsz
,
_seq_length
,
_hidden_size
,
_stream
);
if
(
_pre_or_postLayerNorm
)
{
if
(
_attn_layer_norm
.
UseMean
())
_attn_layer_norm
.
BackwardFusedAdd
(
bsz_seq
,
buf_3
,
grad_output_ptr
,
attn_nw_ptr
,
grad_attn_nw_ptr
,
grad_attn_nb_ptr
,
streams
,
buf_0
,
add_res_ptr
);
else
_attn_layer_norm
.
BackwardFusedAdd
(
bsz_seq
,
buf_3
,
grad_output_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
grad_attn_nw_ptr
,
grad_attn_nb_ptr
,
streams
,
buf_0
,
ff1_inp_ptr
);
}
else
{
if
(
_attn_layer_norm
.
UseMean
())
_attn_layer_norm
.
Backward
(
bsz_seq
,
buf_2
,
attn_nw_ptr
,
grad_attn_nw_ptr
,
grad_attn_nb_ptr
,
streams
,
buf_0
,
add_res_ptr
);
else
_attn_layer_norm
.
Backward
(
bsz_seq
,
buf_2
,
attn_nw_ptr
,
attn_nb_ptr
,
grad_attn_nw_ptr
,
grad_attn_nb_ptr
,
streams
,
buf_0
,
ff1_inp_ptr
);
}
_attn_output_dropout
.
Backward
(
bsz_seq
,
buf_2
,
buf_0
,
_stream
);
T
*
attn_output_dropout_buf
=
_attn_output_dropout
.
HasDropout
()
?
buf_2
:
buf_0
;
_attn_out_linear
.
Backward
(
bsz_seq
,
attn_output_dropout_buf
,
attn_o_inp_ptr
,
attn_ow_ptr
,
grad_attn_ow_ptr
,
grad_attn_ob_ptr
,
_cublasHandle
,
_stream
,
buf_1
);
launch_transform_0213
<
T
>
(
buf_2
,
buf_1
,
bsz
,
_seq_length
,
_hidden_size
,
_heads
,
_stream
);
if
(
_attn_prob_dropout
.
HasDropout
())
{
if
(
_attn_dropout_checkpoint
)
_attn_prob_dropout
.
Forward
(
bsz_heads
*
_seq_length
,
ctx_bufB_ptr_recomp
,
soft_out_ptr
,
_stream
,
true
);
_attn_context
.
Backward
(
bsz_heads
,
buf_2
,
v_tf_ptr
,
(
_attn_dropout_checkpoint
?
ctx_bufB_ptr_recomp
:
ctx_bufB_ptr
),
_cublasHandle
,
buf_3
,
ff2_buf
);
}
else
_attn_context
.
Backward
(
bsz_heads
,
buf_2
,
v_tf_ptr
,
soft_out_ptr
,
_cublasHandle
,
buf_3
,
ff2_buf
);
_attn_prob_dropout
.
Backward
(
bsz_heads
*
_seq_length
,
ff2_buf
,
_stream
);
_softmax
.
Backward
(
bsz
,
ff2_buf
,
soft_out_ptr
,
_stream
);
_attn_scores
.
Backward
(
bsz_heads
,
ff2_buf
,
k_tf_ptr
,
q_tf_ptr
,
_cublasHandle
,
buf_2
,
buf_1
);
launch_transform4d_0213
(
ff2_buf
,
buf_1
,
bsz
,
_heads
,
_seq_length
,
_hidden_size
,
_stream
,
3
);
if
(
_pre_or_postLayerNorm
)
_qkv_linear
.
Backward
(
bsz_seq
,
ff2_buf
,
inp_norm_ptr
,
attn_qkvw_ptr
,
grad_attn_qkvw_ptr
,
grad_attn_qkvb_ptr
,
_cublasHandle
,
_stream
,
buf_2
);
else
_qkv_linear
.
Backward
(
bsz_seq
,
ff2_buf
,
input_ptr
,
attn_qkvw_ptr
,
grad_attn_qkvw_ptr
,
grad_attn_qkvb_ptr
,
_cublasHandle
,
_stream
,
buf_2
);
if
(
_pre_or_postLayerNorm
)
{
if
(
_layer_norm
.
UseMean
())
_layer_norm
.
BackwardFusedAdd
(
bsz_seq
,
buf_2
,
buf_0
,
norm_w_ptr
,
grad_norm_w_ptr
,
grad_norm_b_ptr
,
streams
,
grad_input_ptr
,
input_ptr
);
else
_layer_norm
.
BackwardFusedAdd
(
bsz_seq
,
buf_2
,
buf_0
,
norm_w_ptr
,
norm_b_ptr
,
grad_norm_w_ptr
,
grad_norm_b_ptr
,
streams
,
grad_input_ptr
,
inp_norm_ptr
);
}
else
launch_fused_add2
<
T
>
(
grad_input_ptr
,
buf_2
,
buf_0
,
bsz
,
_seq_length
,
_hidden_size
,
_stream
);
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
SetTrainingMode
(
bool
training
)
{
// Dropout will be skipped when not in training model.
_attn_prob_dropout
.
SetTrainingMode
(
training
);
_attn_output_dropout
.
SetTrainingMode
(
training
);
_layer_output_dropout
.
SetTrainingMode
(
training
);
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
SetIntermediateBuffers
(
uint8_t
*
attn_prob_dropout_mask_ptr
,
uint8_t
*
attn_output_dropout_mask_ptr
,
uint8_t
*
layer_output_dropout_mask_ptr
,
T
*
attn_layer_norm_var
,
T
*
attn_layer_norm_mean
,
T
*
layer_norm_var
,
T
*
layer_norm_mean
)
{
_attn_prob_dropout
.
SetMask
(
attn_prob_dropout_mask_ptr
);
_attn_output_dropout
.
SetMask
(
attn_output_dropout_mask_ptr
);
_layer_output_dropout
.
SetMask
(
layer_output_dropout_mask_ptr
);
_attn_layer_norm
.
SetVar
(
attn_layer_norm_var
);
_attn_layer_norm
.
SetMean
(
attn_layer_norm_mean
);
_layer_norm
.
SetVar
(
layer_norm_var
);
_layer_norm
.
SetMean
(
layer_norm_mean
);
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
SetSeqLength
(
int
seq_len
)
{
_seq_length
=
seq_len
;
_softmax
.
SetSeqLength
(
_seq_length
);
_attn_prob_dropout
.
SetDimension
(
_seq_length
);
_attn_scores
.
SetConfig
(
_seq_length
,
_seq_length
,
_hidden_size
/
_heads
);
_attn_context
.
SetConfig
(
_hidden_size
/
_heads
,
_seq_length
,
_seq_length
);
}
template
<
typename
T
>
int
create_transformer_layer
(
int
layer_id
,
int
batch_size
,
int
hidden_dim
,
int
num_heads
,
int
intermediate_size
,
float
attn_dropout_ratio
,
float
hidden_dropout_ratio
,
float
layer_norm_eps
,
int
seed
,
bool
pre_or_postLayerNorm
,
bool
test_gemm
,
bool
attn_dropout_checkpoint
,
bool
normalize_invertible
,
bool
gelu_checkpoint
,
bool
stochastic_mode
)
{
Context
::
Instance
().
SetSeed
(
seed
);
Context
::
Instance
().
TestGemmFP16
(
test_gemm
,
batch_size
,
init_seq_length
,
num_heads
,
hidden_dim
/
num_heads
);
auto
layer
=
std
::
make_shared
<
BertTransformerLayer
<
T
>>
(
layer_id
,
batch_size
,
hidden_dim
,
num_heads
,
intermediate_size
,
init_seq_length
,
attn_dropout_ratio
,
hidden_dropout_ratio
,
layer_norm_eps
,
pre_or_postLayerNorm
,
Context
::
Instance
().
GetGemmAlgos
(),
attn_dropout_checkpoint
,
normalize_invertible
,
gelu_checkpoint
,
stochastic_mode
);
s_transformer_layers
[
layer_id
]
=
layer
;
std
::
string
dtype
=
(
std
::
is_same
<
T
,
__half
>::
value
)
?
"half"
:
"float"
;
std
::
cout
<<
"layer #"
<<
layer_id
<<
" is created with date type ["
<<
dtype
<<
"]."
<<
std
::
endl
;
return
0
;
}
template <typename T>
std::vector<torch::Tensor> ds_transformer_forward(
    int layer_id,
    const torch::Tensor& input, const torch::Tensor& input_mask,
    const torch::Tensor& attn_qkvw, const torch::Tensor& attn_qkvb,
    const torch::Tensor& attn_ow, const torch::Tensor& attn_ob,
    const torch::Tensor& attn_nw, const torch::Tensor& attn_nb,
    const torch::Tensor& inter_w, const torch::Tensor& inter_b,
    const torch::Tensor& output_w, const torch::Tensor& output_b,
    const torch::Tensor& norm_w, const torch::Tensor& norm_b,
    bool training_mode, bool prelayernorm, bool attn_dropout_checkpoint,
    bool normalize_invertible, bool gelu_checkpoint)
{
    CHECK_INPUT(input);
    CHECK_INPUT(input_mask);
    CHECK_INPUT(attn_qkvw);
    CHECK_INPUT(attn_qkvb);
    CHECK_INPUT(attn_ow);
    CHECK_INPUT(attn_ob);
    CHECK_INPUT(attn_nw);
    CHECK_INPUT(attn_nb);
    CHECK_INPUT(inter_w);
    CHECK_INPUT(inter_b);
    CHECK_INPUT(output_w);
    CHECK_INPUT(output_b);
    CHECK_INPUT(norm_w);
    CHECK_INPUT(norm_b);

    int bsz = input.size(0);

    const T* input_ptr = (const T*)input.data_ptr();
    const T* input_mask_ptr = (const T*)input_mask.data_ptr();
    const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
    const T* attn_qkvb_ptr = (const T*)attn_qkvb.data_ptr();
    const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
    const T* attn_ob_ptr = (const T*)attn_ob.data_ptr();
    const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
    const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
    const T* inter_w_ptr = (const T*)inter_w.data_ptr();
    const T* inter_b_ptr = (const T*)inter_b.data_ptr();
    const T* output_w_ptr = (const T*)output_w.data_ptr();
    const T* output_b_ptr = (const T*)output_b.data_ptr();
    const T* norm_w_ptr = (const T*)norm_w.data_ptr();
    const T* norm_b_ptr = (const T*)norm_b.data_ptr();

    auto output = torch::empty_like(input);
    T* out_ptr = (T*)output.data_ptr();

    auto options = torch::TensorOptions()
                       .dtype(input.options().dtype())
                       .layout(torch::kStrided)
                       .device(torch::kCUDA)
                       .requires_grad(true);

    auto uint8_options = torch::TensorOptions()
                             .dtype(torch::kInt8)
                             .layout(torch::kStrided)
                             .device(torch::kCUDA)
                             .requires_grad(false);

    std::shared_ptr<BertTransformerLayer<T>> layer =
        std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

    int seq_len = layer->GetSeqLength();
    if (input.size(1) != seq_len) {
        seq_len = input.size(1);
        layer->SetSeqLength(seq_len);
    }

    auto workspace = torch::empty({get_workspace_size<T>(bsz,
                                                         seq_len,
                                                         layer->GetHiddenSize(),
                                                         layer->GetIntermediateSize(),
                                                         layer->GetNumHeads(),
                                                         layer->IsTrainingMode(),
                                                         layer->GeluCheckpoint())},
                                  options);
    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());

    auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
    auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
    auto attn_o_inp = torch::empty_like(input);
    auto qkv_tf = torch::empty({(bsz * seq_len), output_w.size(0) * 3}, options);

    auto attn_prob_dropout_mask =
        torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, uint8_options);
    auto attn_output_dropout_mask =
        torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);
    auto layer_output_dropout_mask =
        torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);

    auto attn_layer_norm_var = torch::empty({(bsz * seq_len)}, options);
    auto attn_layer_norm_mean = torch::empty({(bsz * seq_len)}, options);
    auto layer_norm_var = torch::empty({(bsz * seq_len)}, options);
    auto layer_norm_mean = torch::empty({(bsz * seq_len)}, options);

    T* inp_norm_ptr = (T*)inp_norm.data_ptr();
    T* add_res_ptr = (T*)add_res.data_ptr();
    T* q_tf_ptr = (T*)qkv_tf.data_ptr();
    T* k_tf_ptr = q_tf_ptr + (bsz * seq_len * output_w.size(0));  //(T*)k_tf.data_ptr();
    T* v_tf_ptr = k_tf_ptr + (bsz * seq_len * output_w.size(0));  //(T*)v_tf.data_ptr();
    T* attn_o_inp_ptr = (T*)attn_o_inp.data_ptr();

    torch::Tensor ff2_inp = torch::empty({(bsz * seq_len), output_w.size(1)}, options);
    torch::Tensor gelu_inp =
        (gelu_checkpoint ? ff2_inp : torch::empty({(bsz * seq_len), output_w.size(1)}, options));
    auto ff1_inp = torch::empty_like(input);
    T* ff2_inp_ptr = (T*)ff2_inp.data_ptr();
    T* gelu_inp_ptr = (T*)gelu_inp.data_ptr();
    T* ff1_inp_ptr = (T*)ff1_inp.data_ptr();

    torch::Tensor soft_out =
        torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options);
    torch::Tensor ctx_bufB =
        (attn_dropout_checkpoint
             ? soft_out
             : torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options));
    T* soft_out_ptr = (T*)soft_out.data_ptr();
    T* ctx_bufB_ptr = (T*)ctx_bufB.data_ptr();

    layer->SetTrainingMode(training_mode);
    layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
                                  (uint8_t*)attn_output_dropout_mask.data_ptr(),
                                  (uint8_t*)layer_output_dropout_mask.data_ptr(),
                                  (T*)attn_layer_norm_var.data_ptr(),
                                  (T*)attn_layer_norm_mean.data_ptr(),
                                  (T*)layer_norm_var.data_ptr(),
                                  (T*)layer_norm_mean.data_ptr());

    layer->Forward(bsz, input_ptr, input_mask_ptr, attn_qkvw_ptr, attn_qkvb_ptr, attn_ow_ptr,
                   attn_ob_ptr, attn_nw_ptr, attn_nb_ptr, inter_w_ptr, inter_b_ptr, output_w_ptr,
                   output_b_ptr, norm_w_ptr, norm_b_ptr, out_ptr, inp_norm_ptr, q_tf_ptr,
                   k_tf_ptr, v_tf_ptr, soft_out_ptr, ctx_bufB_ptr, attn_o_inp_ptr, add_res_ptr,
                   ff1_inp_ptr, gelu_inp_ptr, ff2_inp_ptr);

    return {output,
            inp_norm,
            qkv_tf,
            soft_out,
            ctx_bufB,
            attn_o_inp,
            add_res,
            ff1_inp,
            gelu_inp,
            ff2_inp,
            attn_prob_dropout_mask,
            attn_output_dropout_mask,
            layer_output_dropout_mask,
            attn_layer_norm_var,
            attn_layer_norm_mean,
            layer_norm_var,
            layer_norm_mean};
}
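// ds_transformer_backward below consumes the activation stash returned by
// ds_transformer_forward above (output, inp_norm, qkv_tf, soft_out, ctx_bufB,
// attn_o_inp, add_res, ff1_inp, gelu_inp, ff2_inp, plus the dropout masks and
// layer-norm statistics), together with the original inputs and parameters,
// and returns the input gradient followed by the parameter gradients in the
// same order as the forward parameter arguments.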
template <typename T>
std::vector<torch::Tensor> ds_transformer_backward(
    int layer_id,
    const torch::Tensor& grad_output, const torch::Tensor& output,
    const torch::Tensor& inp_norm, const torch::Tensor& qkv_tf,
    const torch::Tensor& soft_out, const torch::Tensor& ctx_bufB,
    const torch::Tensor& attn_o_inp, const torch::Tensor& add_res,
    const torch::Tensor& ff1_inp, const torch::Tensor& gelu_inp,
    const torch::Tensor& ff2_inp,
    const torch::Tensor& attn_prob_dropout_mask,
    const torch::Tensor& attn_output_dropout_mask,
    const torch::Tensor& layer_output_dropout_mask,
    const torch::Tensor& attn_layer_norm_var,
    const torch::Tensor& attn_layer_norm_mean,
    const torch::Tensor& layer_norm_var, const torch::Tensor& layer_norm_mean,
    const torch::Tensor& input, const torch::Tensor& input_mask,
    const torch::Tensor& attn_qkvw, const torch::Tensor& attn_qkvb,
    const torch::Tensor& attn_ow, const torch::Tensor& attn_ob,
    const torch::Tensor& attn_nw, const torch::Tensor& attn_nb,
    const torch::Tensor& inter_w, const torch::Tensor& inter_b,
    const torch::Tensor& output_w, const torch::Tensor& output_b,
    const torch::Tensor& norm_w, const torch::Tensor& norm_b)
{
    auto g_output = grad_output.contiguous();
    CHECK_INPUT(g_output);
    CHECK_INPUT(output);
    CHECK_INPUT(inp_norm);
    CHECK_INPUT(qkv_tf);
    CHECK_INPUT(add_res);
    CHECK_INPUT(soft_out);
    CHECK_INPUT(ctx_bufB);
    CHECK_INPUT(attn_o_inp);
    CHECK_INPUT(ff1_inp);
    CHECK_INPUT(gelu_inp);
    CHECK_INPUT(ff2_inp);
    CHECK_INPUT(input);
    CHECK_INPUT(input_mask);
    CHECK_INPUT(attn_qkvw);
    CHECK_INPUT(attn_qkvb);
    CHECK_INPUT(attn_ow);
    CHECK_INPUT(attn_ob);
    CHECK_INPUT(attn_nw);
    CHECK_INPUT(attn_nb);
    CHECK_INPUT(inter_w);
    CHECK_INPUT(inter_b);
    CHECK_INPUT(output_w);
    CHECK_INPUT(output_b);
    CHECK_INPUT(norm_w);
    CHECK_INPUT(norm_b);

    int bsz = g_output.size(0);

    std::shared_ptr<BertTransformerLayer<T>> layer =
        std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

    int seq_len = layer->GetSeqLength();
    if (g_output.size(1) != seq_len) {
        seq_len = g_output.size(1);
        layer->SetSeqLength(seq_len);
    }

    auto options = torch::TensorOptions()
                       .dtype(g_output.options().dtype())
                       .layout(torch::kStrided)
                       .device(torch::kCUDA)
                       .requires_grad(true);

    auto workspace = torch::empty({get_workspace_size<T>(bsz,
                                                         seq_len,
                                                         layer->GetHiddenSize(),
                                                         layer->GetIntermediateSize(),
                                                         layer->GetNumHeads(),
                                                         layer->IsTrainingMode(),
                                                         layer->GeluCheckpoint())},
                                  options);
    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());

    auto grad_input = torch::empty_like(input);
    auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
    auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
    auto grad_attn_ow = torch::empty_like(attn_ow);
    auto grad_attn_ob = torch::empty_like(attn_ob);
    auto grad_attn_nw = torch::empty_like(attn_nw);
    auto grad_attn_nb = torch::empty_like(attn_nb);
    auto grad_inter_w = torch::empty_like(inter_w);
    auto grad_inter_b = torch::empty_like(inter_b);
    auto grad_output_w = torch::empty_like(output_w);
    auto grad_output_b = torch::empty_like(output_b);
    auto grad_norm_w = torch::empty_like(norm_w);
    auto grad_norm_b = torch::empty_like(norm_b);

    // inputs.
    const T* grad_output_ptr = (const T*)g_output.data_ptr();
    const T* input_ptr = (const T*)input.data_ptr();
    const T* output_ptr = (const T*)output.data_ptr();
    const T* inp_norm_ptr = (const T*)inp_norm.data_ptr();
    const T* q_tf_ptr = (const T*)qkv_tf.data_ptr();
    const T* add_res_ptr = (const T*)add_res.data_ptr();
    const T* k_tf_ptr =
        q_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));  //(const T*)k_tf.data_ptr();
    const T* v_tf_ptr =
        k_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));  //(const T*)v_tf.data_ptr();
    const T* ff1_inp_ptr = (const T*)ff1_inp.data_ptr();
    const T* gelu_inp_ptr = (const T*)gelu_inp.data_ptr();
    const T* ff2_inp_ptr = (const T*)ff2_inp.data_ptr();
    const T* ctx_bufB_ptr = (const T*)ctx_bufB.data_ptr();
    const T* soft_out_ptr = (const T*)soft_out.data_ptr();
    const T* attn_o_inp_ptr = (const T*)attn_o_inp.data_ptr();
    const T* input_mask_ptr = (const T*)input_mask.data_ptr();
    const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
    const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
    const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
    const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
    const T* inter_w_ptr = (const T*)inter_w.data_ptr();
    const T* inter_b_ptr = (const T*)inter_b.data_ptr();
    const T* output_w_ptr = (const T*)output_w.data_ptr();
    const T* norm_w_ptr = (const T*)norm_w.data_ptr();
    const T* norm_b_ptr = (const T*)norm_b.data_ptr();

    // outputs.
    T* grad_input_ptr = (T*)grad_input.data_ptr();
    T* grad_attn_qkvw_ptr = (T*)grad_attn_qkvw.data_ptr();
    T* grad_attn_qkvb_ptr = (T*)grad_attn_qkvb.data_ptr();
    T* grad_attn_ow_ptr = (T*)grad_attn_ow.data_ptr();
    T* grad_attn_ob_ptr = (T*)grad_attn_ob.data_ptr();
    T* grad_attn_nw_ptr = (T*)grad_attn_nw.data_ptr();
    T* grad_attn_nb_ptr = (T*)grad_attn_nb.data_ptr();
    T* grad_inter_w_ptr = (T*)grad_inter_w.data_ptr();
    T* grad_inter_b_ptr = (T*)grad_inter_b.data_ptr();
    T* grad_output_w_ptr = (T*)grad_output_w.data_ptr();
    T* grad_output_b_ptr = (T*)grad_output_b.data_ptr();
    T* grad_norm_w_ptr = (T*)grad_norm_w.data_ptr();
    T* grad_norm_b_ptr = (T*)grad_norm_b.data_ptr();

    layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
                                  (uint8_t*)attn_output_dropout_mask.data_ptr(),
                                  (uint8_t*)layer_output_dropout_mask.data_ptr(),
                                  (T*)attn_layer_norm_var.data_ptr(),
                                  (T*)attn_layer_norm_mean.data_ptr(),
                                  (T*)layer_norm_var.data_ptr(),
                                  (T*)layer_norm_mean.data_ptr());

    layer->Backward(bsz, grad_output_ptr, input_ptr, output_ptr, inp_norm_ptr, q_tf_ptr,
                    k_tf_ptr, v_tf_ptr, soft_out_ptr, ctx_bufB_ptr, attn_o_inp_ptr, add_res_ptr,
                    ff1_inp_ptr, gelu_inp_ptr, ff2_inp_ptr, input_mask_ptr, attn_qkvw_ptr,
                    attn_ow_ptr, attn_nw_ptr, attn_nb_ptr, inter_w_ptr, inter_b_ptr,
                    output_w_ptr, norm_w_ptr, norm_b_ptr, grad_input_ptr, grad_attn_qkvw_ptr,
                    grad_attn_qkvb_ptr, grad_attn_ow_ptr, grad_attn_ob_ptr, grad_attn_nw_ptr,
                    grad_attn_nb_ptr, grad_inter_w_ptr, grad_inter_b_ptr, grad_output_w_ptr,
                    grad_output_b_ptr, grad_norm_w_ptr, grad_norm_b_ptr);

    return {grad_input,
            grad_attn_qkvw,
            grad_attn_qkvb,
            grad_attn_ow,
            grad_attn_ob,
            grad_attn_nw,
            grad_attn_nb,
            grad_inter_w,
            grad_inter_b,
            grad_output_w,
            grad_output_b,
            grad_norm_w,
            grad_norm_b};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward_fp32",
          &ds_transformer_forward<float>,
          "DeepSpeed Transformer forward with fp32 (CUDA)");
    m.def("forward_fp16",
          &ds_transformer_forward<__half>,
          "DeepSpeed Transformer forward with fp16 (CUDA)");
    m.def("backward_fp32",
          &ds_transformer_backward<float>,
          "DeepSpeed Transformer backward with fp32 (CUDA)");
    m.def("backward_fp16",
          &ds_transformer_backward<__half>,
          "DeepSpeed Transformer backward with fp16 (CUDA)");
    m.def("create_transformer_layer_fp32",
          &create_transformer_layer<float>,
          "Create DeepSpeed Transformer Transformer Layer with fp32 (CUDA)");
    m.def("create_transformer_layer_fp16",
          &create_transformer_layer<__half>,
          "Create DeepSpeed Transformer Transformer Layer with fp16 (CUDA)");
}
deepspeed/ops/csrc/transformer/gelu_kernels.cu
0 → 100644
#include "custom_cuda_layers.h"
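// gelu() below evaluates the tanh approximation of the Gaussian Error Linear
// Unit, gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
// with sqrt(2/pi) precomputed as sqrt_param. d_gelu() is its analytic
// derivative, used when back-propagating through the fused bias-GELU.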
inline __device__ float gelu(const float x)
{
    const float sqrt_param = 0.79788456080286535587989211986876f;
    const float mul_param = 0.044715;
    return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}

inline __device__ float d_gelu(const float x)
{
    const float sqrt_param = 0.79788456080286535587989211986876f;
    const float mul_param = 0.044715;

    float x2mul = x * x * mul_param;
    float tan_h = tanhf(sqrt_param * (x + x * x2mul));
    float dg1 = 0.5f * (1.0f + tan_h);
    float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
    float dg3 = dg2 * 3 * x2mul;
    return (dg1 + dg2 + dg3);
}
/*
Fused bias add with GELU

Loads a vector of 4 elements each iteration, for `iterations` strides over the
row. It was written with the intention of launching 256-thread threadblocks, so
for bert-large we would set ITERATIONS to 4. This is currently done
automatically as a heuristic, choosing the number of iterations as blocks of
1024 elements (see the worked sizing example after launch_bias_gelu below).

For FP16, the values are loaded from memory as __half but converted to FP32 for
the arithmetic itself, to prevent overflow on the intermediate hyperbolic
tangent, since there is no intrinsic that computes it directly in half
precision.
*/
__global__ void gelu_kernel(const float* input, float* vals, int intermediate_size)
{
    int row = blockIdx.x;
    int id = threadIdx.x;
    int loop_stride = blockDim.x;
    int iterations = intermediate_size / blockDim.x / 4;
    int row_stride = intermediate_size / 4;

    const float4* input_cast = reinterpret_cast<const float4*>(input);
    float4* vals_cast = reinterpret_cast<float4*>(vals);

    for (int i = 0; i < iterations; i++) {
        if (i * loop_stride + id < row_stride) {
            float4 data = input_cast[row * row_stride + i * loop_stride + id];

            data.x = gelu(data.x);
            data.y = gelu(data.y);
            data.z = gelu(data.z);
            data.w = gelu(data.w);

            vals_cast[row * row_stride + i * loop_stride + id] = data;
        }
    }
}

__global__ void gelu_kernel(const __half* input, __half* vals, int intermediate_size)
{
#if __CUDA_ARCH__ >= 700
    int row = blockIdx.x;
    int id = threadIdx.x;
    int loop_stride = blockDim.x;
    int iterations = intermediate_size / blockDim.x / 4;
    int row_stride = intermediate_size / 4;

    const float2* input_cast = reinterpret_cast<const float2*>(input);
    float2* vals_cast = reinterpret_cast<float2*>(vals);

    for (int i = 0; i < iterations; i++) {
        if (i * loop_stride + id < row_stride) {
            float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];

            __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);

            float2 low_data = __half22float2(vals_half[0]);
            float2 high_data = __half22float2(vals_half[1]);

            low_data.x = gelu(low_data.x);
            low_data.y = gelu(low_data.y);
            high_data.x = gelu(high_data.x);
            high_data.y = gelu(high_data.y);

            vals_half[0] = __float22half2_rn(low_data);
            vals_half[1] = __float22half2_rn(high_data);

            vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
        }
    }
#endif
}
__global__ void fused_bias_gelu(const float* input,
                                const float* bias,
                                float* vals,
                                int intermediate_size)
{
    int row = blockIdx.x;
    int id = threadIdx.x;
    int loop_stride = blockDim.x;
    int iterations = intermediate_size / blockDim.x / 4;
    int row_stride = intermediate_size / 4;

    const float4* input_cast = reinterpret_cast<const float4*>(input);
    float4* vals_cast = reinterpret_cast<float4*>(vals);
    const float4* bias_cast = reinterpret_cast<const float4*>(bias);

    for (int i = 0; i < iterations; i++) {
        if (i * loop_stride + id < row_stride) {
            float4 data = input_cast[row * row_stride + i * loop_stride + id];
            float4 bias_data = bias_cast[i * loop_stride + id];

            data.x += bias_data.x;
            data.y += bias_data.y;
            data.z += bias_data.z;
            data.w += bias_data.w;

            data.x = gelu(data.x);
            data.y = gelu(data.y);
            data.z = gelu(data.z);
            data.w = gelu(data.w);

            vals_cast[row * row_stride + i * loop_stride + id] = data;
        }
    }
}

__global__ void fused_bias_gelu(const __half* input,
                                const __half* bias,
                                __half* vals,
                                int intermediate_size)
{
#if __CUDA_ARCH__ >= 700
    int row = blockIdx.x;
    int id = threadIdx.x;
    int loop_stride = blockDim.x;
    int iterations = intermediate_size / blockDim.x / 4;
    int row_stride = intermediate_size / 4;

    const float2* input_cast = reinterpret_cast<const float2*>(input);
    float2* vals_cast = reinterpret_cast<float2*>(vals);
    const float2* bias_cast = reinterpret_cast<const float2*>(bias);

    for (int i = 0; i < iterations; i++) {
        if (i * loop_stride + id < row_stride) {
            float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];
            float2 bias_vec = bias_cast[i * loop_stride + id];

            __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
            __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);

            float2 low_data = __half22float2(vals_half[0]);
            float2 high_data = __half22float2(vals_half[1]);
            float2 low_bias = __half22float2(bias_half[0]);
            float2 high_bias = __half22float2(bias_half[1]);

            low_data.x += low_bias.x;
            low_data.y += low_bias.y;
            high_data.x += high_bias.x;
            high_data.y += high_bias.y;

            low_data.x = gelu(low_data.x);
            low_data.y = gelu(low_data.y);
            high_data.x = gelu(high_data.x);
            high_data.y = gelu(high_data.y);

            vals_half[0] = __float22half2_rn(low_data);
            vals_half[1] = __float22half2_rn(high_data);

            vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
        }
    }
#endif
}
__global__ void d_gelu_func(float* d_output,
                            const float* gelu_input,
                            const float* bias,
                            int intermediate_size)
{
    int row = blockIdx.x;
    int id = threadIdx.x;
    int loop_stride = blockDim.x;
    int iterations = intermediate_size / blockDim.x / 4;
    int row_stride = intermediate_size / 4;

    float4* d_output_cast = reinterpret_cast<float4*>(d_output);
    const float4* gelu_input_cast = reinterpret_cast<const float4*>(gelu_input);
    const float4* bias_cast = reinterpret_cast<const float4*>(bias);

    for (int i = 0; i < iterations; i++) {
        if (i * loop_stride + id < row_stride) {
            float4 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
            float4 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
            float4 bias_data = bias_cast[i * loop_stride + id];

            gelu_input_data.x += bias_data.x;
            gelu_input_data.y += bias_data.y;
            gelu_input_data.z += bias_data.z;
            gelu_input_data.w += bias_data.w;

            output_data.x *= d_gelu(gelu_input_data.x);
            output_data.y *= d_gelu(gelu_input_data.y);
            output_data.z *= d_gelu(gelu_input_data.z);
            output_data.w *= d_gelu(gelu_input_data.w);

            d_output_cast[row * row_stride + i * loop_stride + id] = output_data;
        }
    }
}

__global__ void d_gelu_func(__half* d_output,
                            const __half* gelu_input,
                            const __half* bias,
                            int intermediate_size)
{
#if __CUDA_ARCH__ >= 700
    int row = blockIdx.x;
    int id = threadIdx.x;
    int loop_stride = blockDim.x;
    int iterations = intermediate_size / blockDim.x / 4;
    int row_stride = intermediate_size / 4;

    float2* d_output_cast = reinterpret_cast<float2*>(d_output);
    const float2* gelu_input_cast = reinterpret_cast<const float2*>(gelu_input);
    const float2* bias_cast = reinterpret_cast<const float2*>(bias);

#pragma unroll
    for (int i = 0; i < iterations; i++) {
        if (i * loop_stride + id < row_stride) {
            float2 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
            float2 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
            float2 bias_vec = bias_cast[i * loop_stride + id];

            __half2* output_data_half = reinterpret_cast<__half2*>(&output_data);
            __half2* gelu_input_data_half = reinterpret_cast<__half2*>(&gelu_input_data);
            __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);

            float2 output_half_0 = __half22float2(output_data_half[0]);
            float2 output_half_1 = __half22float2(output_data_half[1]);
            float2 gelu_input_half_0 = __half22float2(gelu_input_data_half[0]);
            float2 gelu_input_half_1 = __half22float2(gelu_input_data_half[1]);
            float2 bias_half_0 = __half22float2(bias_half[0]);
            float2 bias_half_1 = __half22float2(bias_half[1]);

            gelu_input_half_0.x += bias_half_0.x;
            gelu_input_half_0.y += bias_half_0.y;
            gelu_input_half_1.x += bias_half_1.x;
            gelu_input_half_1.y += bias_half_1.y;

            output_half_0.x *= d_gelu(gelu_input_half_0.x);
            output_half_0.y *= d_gelu(gelu_input_half_0.y);
            output_half_1.x *= d_gelu(gelu_input_half_1.x);
            output_half_1.y *= d_gelu(gelu_input_half_1.y);

            float2 result;
            __half2* result_half2 = reinterpret_cast<__half2*>(&result);

            result_half2[0] = __float22half2_rn(output_half_0);
            result_half2[1] = __float22half2_rn(output_half_1);

            d_output_cast[row * row_stride + i * loop_stride + id] = result;
        }
    }
#endif
}
template <typename T>
void launch_bias_gelu(const T* input,
                      const T* bias,
                      T* output,
                      int intermediate_size,
                      int batch_size,
                      cudaStream_t stream)
{
    int iterations = (intermediate_size + 1023) / 1024;
    int threads = intermediate_size / iterations / 4;
    dim3 block_dims(threads);
    dim3 grid_dims(batch_size);

    fused_bias_gelu<<<grid_dims, block_dims, 0, stream>>>(input, bias, output, intermediate_size);
}
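// Worked example of the sizing heuristic above (numbers are illustrative): for
// bert-large's intermediate size of 4096, iterations = (4096 + 1023) / 1024 = 4
// and threads = 4096 / 4 / 4 = 256, i.e. one 256-thread block per row with four
// vectorized loads per thread.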
template <typename T>
void launch_gelu(const T* input,
                 T* output,
                 int intermediate_size,
                 int batch_size,
                 cudaStream_t stream)
{
    int iterations = (intermediate_size + 1023) / 1024;
    int threads = intermediate_size / iterations / 4;
    dim3 block_dims(threads);
    dim3 grid_dims(batch_size);

    gelu_kernel<<<grid_dims, block_dims, 0, stream>>>(input, output, intermediate_size);
}

template void launch_bias_gelu<float>(const float*, const float*, float*, int, int, cudaStream_t);
template void launch_bias_gelu<__half>(const __half*,
                                       const __half*,
                                       __half*,
                                       int,
                                       int,
                                       cudaStream_t);

template void launch_gelu<float>(const float*, float*, int, int, cudaStream_t);
template void launch_gelu<__half>(const __half*, __half*, int, int, cudaStream_t);
template <typename T>
void launch_d_gelu(T* d_output,
                   const T* input,
                   const T* bias,
                   int intermediate_size,
                   int batch_size,
                   cudaStream_t stream)
{
    int iterations = (intermediate_size + 1023) / 1024;
    int threads = intermediate_size / iterations / 4;
    dim3 block_dims(threads);
    dim3 grid_dims(batch_size);

    d_gelu_func<<<grid_dims, block_dims, 0, stream>>>(d_output, input, bias, intermediate_size);
}

template void launch_d_gelu<float>(float*, const float*, const float*, int, int, cudaStream_t);
template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, cudaStream_t);
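The launchers above are plain C++ entry points, so they can be exercised outside of PyTorch. The following is a minimal smoke-test sketch for launch_gelu<float>, not part of this commit; it assumes the launcher declarations are exposed through "custom_cuda_layers.h" (as this file's own include suggests) and that gelu_kernels.cu is compiled and linked alongside it. The buffer sizes and input values are arbitrary.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

#include <cuda_runtime.h>

#include "custom_cuda_layers.h"  // assumed to declare launch_gelu<T>

int main()
{
    const int rows = 2;      // gelu_kernel runs one block per row
    const int width = 4096;  // must be a multiple of 4 for the float4 path
    const int count = rows * width;

    std::vector<float> h_in(count), h_out(count);
    for (int i = 0; i < count; ++i) h_in[i] = 0.001f * (i % 2000) - 1.0f;

    float* d_in = nullptr;
    float* d_out = nullptr;
    cudaMalloc(&d_in, count * sizeof(float));
    cudaMalloc(&d_out, count * sizeof(float));
    cudaMemcpy(d_in, h_in.data(), count * sizeof(float), cudaMemcpyHostToDevice);

    // batch_size is the row count; the launcher derives the block size itself.
    launch_gelu<float>(d_in, d_out, width, rows, 0);

    cudaMemcpy(h_out.data(), d_out, count * sizeof(float), cudaMemcpyDeviceToHost);

    // Compare against the same tanh approximation evaluated on the host.
    double max_err = 0.0;
    for (int i = 0; i < count; ++i) {
        const float x = h_in[i];
        const float ref = x * 0.5f * (1.0f + std::tanh(0.79788456f * (x + 0.044715f * x * x * x)));
        max_err = std::max(max_err, static_cast<double>(std::fabs(ref - h_out[i])));
    }
    std::printf("launch_gelu max abs error: %g\n", max_err);

    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}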
deepspeed/ops/csrc/transformer/general_kernels.cu
0 → 100644
#include "general_kernels.h"

namespace cg = cooperative_groups;
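// column_sum_reduce computes per-column sums of a (rows x width) matrix: each
// thread accumulates a strided partial sum down one column into a shared
// TILE_DIM x (TILE_DIM + 1) tile (the +1 padding avoids shared-memory bank
// conflicts), the tile is read back transposed, and warp shuffles reduce each
// tile row so that one thread per column writes the final sum.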
template <typename T>
__global__ void column_sum_reduce(const T* __restrict__ inp,
                                  T* __restrict__ out,
                                  int rows,
                                  int width)
{
    __shared__ float tile[TILE_DIM][TILE_DIM + 1];

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);

    int idx = blockDim.x * blockIdx.x + threadIdx.x;

    int y_stride = width * TILE_DIM;

    float localSum = 0;

    // Loop across matrix height
    if (idx < width) {
        int offset = threadIdx.y * width + idx;
        for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
            localSum += (float)inp[offset];
            offset += y_stride;
        }
    }

    tile[threadIdx.x][threadIdx.y] = localSum;

    __syncthreads();

    // Sum the shared buffer.
    float sum = tile[threadIdx.y][threadIdx.x];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < TILE_DIM; i <<= 1) sum += g.shfl_down(sum, i);

    if (threadIdx.x == 0) {
        int pos = blockIdx.x * TILE_DIM + threadIdx.y;
        if (pos < width) out[pos] = sum;
    }
}
template <typename T>
void launch_fuse_transpose_bias_kernel(const T* inp,
                                       T* out,
                                       int rows,
                                       int cols,
                                       cudaStream_t stream);

template <>
void launch_fuse_transpose_bias_kernel<float>(const float* inp,
                                              float* out,
                                              int rows,
                                              int cols,
                                              cudaStream_t stream)
{
    // assert(rows % TILE_DIM == 0);
    // assert(cols % TILE_DIM == 0);

    dim3 grid_dim((cols - 1) / TILE_DIM + 1);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    column_sum_reduce<float><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
}

template <>
void launch_fuse_transpose_bias_kernel<__half>(const __half* inp,
                                               __half* out,
                                               int rows,
                                               int cols,
                                               cudaStream_t stream)
{
    // assert(rows % TILE_DIM == 0);
    // assert(cols % TILE_DIM == 0);

    dim3 grid_dim((cols - 1) / TILE_DIM + 1);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    column_sum_reduce<__half><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
}
__global__ void fused_add2_kernel(const int N, float* out, const float* inp1, const float* inp2)
{
    const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
    const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
    float4* out_4 = reinterpret_cast<float4*>(out);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        float4 val;
        float4 inp1_reg = inp1_4[j];
        float4 inp2_reg = inp2_4[j];

        val.x = inp1_reg.x + inp2_reg.x;
        val.y = inp1_reg.y + inp2_reg.y;
        val.z = inp1_reg.z + inp2_reg.z;
        val.w = inp1_reg.w + inp2_reg.w;

        out_4[j] = val;
    }
}

__global__ void fused_add2_kernel(const int N, __half* out, const __half* inp1, const __half* inp2)
{
    float2 inp1_4;
    float2 inp2_4;

    __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
    __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);

    const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
    const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        inp1_4 = inp1_arr[j];
        inp2_4 = inp2_arr[j];

        float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
        float2 inp1_h_f_1 = __half22float2(inp1_h[1]);

        float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
        float2 inp2_h_f_1 = __half22float2(inp2_h[1]);

        inp1_h_f_0.x += inp2_h_f_0.x;
        inp1_h_f_0.y += inp2_h_f_0.y;
        inp1_h_f_1.x += inp2_h_f_1.x;
        inp1_h_f_1.y += inp2_h_f_1.y;

        float2 val_f;
        __half2* val_h = reinterpret_cast<__half2*>(&val_f);
        val_h[0] = __float22half2_rn(inp1_h_f_0);
        val_h[1] = __float22half2_rn(inp1_h_f_1);

        float2* out_4 = reinterpret_cast<float2*>(out);
        out_4[j] = val_f;
    }
}
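// The launchers below index the tensors in packets of four elements (float4
// for FP32, two __half2 pairs read as float2 for FP16), so
// batch_size * seq_length * hidden_dim is expected to be a multiple of 4 and
// the buffers aligned for vectorized loads.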
template <>
void launch_fused_add2<float>(float* out,
                              const float* inp1,
                              const float* inp2,
                              int batch_size,
                              int seq_length,
                              int hidden_dim,
                              cudaStream_t& stream)
{
    int total_count = batch_size * seq_length * hidden_dim / 4;
    dim3 grid_dim = DS_GET_BLOCKS(total_count);  //(batch_size * seq_length);
    dim3 block_dim = DS_CUDA_NUM_THREADS;        //(hidden_dim / 4);

    fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(total_count, out, inp1, inp2);
}

template <>
void launch_fused_add2<__half>(__half* out,
                               const __half* inp1,
                               const __half* inp2,
                               int batch_size,
                               int seq_length,
                               int hidden_dim,
                               cudaStream_t& stream)
{
    int total_count = batch_size * seq_length * hidden_dim / 4;
    dim3 grid_dim = DS_GET_BLOCKS(total_count);  //(batch_size * seq_length);
    dim3 block_dim = DS_CUDA_NUM_THREADS;        //(hidden_dim / 4);

    fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(total_count, out, inp1, inp2);
}
__global__ void fused_add3_kernel(float* out,
                                  const float* inp1,
                                  const float* inp2,
                                  const float* inp3,
                                  int size,
                                  int row_stride)
{
    int row = blockIdx.x;
    int id = threadIdx.x;

    const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
    const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
    const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);

    float4* out_4 = reinterpret_cast<float4*>(out);

    float4 val;
    float4 inp1_reg = inp1_4[row * row_stride + id];
    float4 inp2_reg = inp2_4[row * row_stride + id];
    float4 inp3_reg = inp3_4[row * row_stride + id];

    val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x;
    val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y;
    val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z;
    val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w;

    out_4[row * row_stride + id] = val;
}

__global__ void fused_add3_kernel(__half* out,
                                  const __half* inp1,
                                  const __half* inp2,
                                  const __half* inp3,
                                  int size,
                                  int row_stride)
{
    int row = blockIdx.x;
    int id = threadIdx.x;

    const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
    const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
    const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);

    float2 inp1_4 = inp1_arr[row * row_stride + id];
    float2 inp2_4 = inp2_arr[row * row_stride + id];
    float2 inp3_4 = inp3_arr[row * row_stride + id];

    __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
    __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
    __half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);

    float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
    float2 inp1_h_f_1 = __half22float2(inp1_h[1]);

    float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
    float2 inp2_h_f_1 = __half22float2(inp2_h[1]);

    float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
    float2 inp3_h_f_1 = __half22float2(inp3_h[1]);

    inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x);
    inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y);
    inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x);
    inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y);

    float2 val_f;
    __half2* val_h = reinterpret_cast<__half2*>(&val_f);
    val_h[0] = __float22half2_rn(inp1_h_f_0);
    val_h[1] = __float22half2_rn(inp1_h_f_1);

    float2* out_4 = reinterpret_cast<float2*>(out);
    out_4[row * row_stride + id] = val_f;
}
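// fused_add3_kernel and fused_add4_kernel are launched with one block per
// (batch, sequence) position and hidden_size / 4 threads, each thread handling
// one four-element packet, so hidden_size must be a multiple of 4 and no
// larger than 4096 (the 1024-thread block limit).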
template <>
void launch_fused_add3<float>(float* out,
                              const float* inp1,
                              const float* inp2,
                              const float* inp3,
                              int batch_size,
                              int seq_length,
                              int hidden_size,
                              cudaStream_t& stream)
{
    dim3 grid_dim(batch_size * seq_length);
    dim3 block_dim(hidden_size / 4);

    fused_add3_kernel<<<grid_dim, block_dim, 0, stream>>>(
        out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}

template <>
void launch_fused_add3<__half>(__half* out,
                               const __half* inp1,
                               const __half* inp2,
                               const __half* inp3,
                               int batch_size,
                               int seq_length,
                               int hidden_size,
                               cudaStream_t& stream)
{
    dim3 grid_dim(batch_size * seq_length);
    dim3 block_dim(hidden_size / 4);

    fused_add3_kernel<<<grid_dim, block_dim, 0, stream>>>(
        out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
__global__ void fused_add4_kernel(float* out,
                                  const float* inp1,
                                  const float* inp2,
                                  const float* inp3,
                                  const float* inp4,
                                  int size,
                                  int row_stride)
{
    int row = blockIdx.x;
    int id = threadIdx.x;

    const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
    const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
    const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
    const float4* inp4_4 = reinterpret_cast<const float4*>(inp4);
    float4* out_4 = reinterpret_cast<float4*>(out);

    float4 val;
    float4 inp1_reg = inp1_4[row * row_stride + id];
    float4 inp2_reg = inp2_4[row * row_stride + id];
    float4 inp3_reg = inp3_4[row * row_stride + id];
    float4 inp4_reg = inp4_4[row * row_stride + id];

    val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x + inp4_reg.x;
    val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y + inp4_reg.y;
    val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z + inp4_reg.z;
    val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w + inp4_reg.w;

    out_4[row * row_stride + id] = val;
}

__global__ void fused_add4_kernel(__half* out,
                                  const __half* inp1,
                                  const __half* inp2,
                                  const __half* inp3,
                                  const __half* inp4,
                                  int size,
                                  int row_stride)
{
    int row = blockIdx.x;
    int id = threadIdx.x;

    const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
    const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
    const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
    const float2* inp4_arr = reinterpret_cast<const float2*>(inp4);

    float2 inp1_4 = inp1_arr[row * row_stride + id];
    float2 inp2_4 = inp2_arr[row * row_stride + id];
    float2 inp3_4 = inp3_arr[row * row_stride + id];
    float2 inp4_4 = inp4_arr[row * row_stride + id];

    __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
    __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
    __half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
    __half2* inp4_h = reinterpret_cast<__half2*>(&inp4_4);

    float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
    float2 inp1_h_f_1 = __half22float2(inp1_h[1]);

    float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
    float2 inp2_h_f_1 = __half22float2(inp2_h[1]);

    float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
    float2 inp3_h_f_1 = __half22float2(inp3_h[1]);

    float2 inp4_h_f_0 = __half22float2(inp4_h[0]);
    float2 inp4_h_f_1 = __half22float2(inp4_h[1]);

    inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x + inp4_h_f_0.x);
    inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y + inp4_h_f_0.y);
    inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x + inp4_h_f_1.x);
    inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y + inp4_h_f_1.y);

    float2 val_f;
    __half2* val_h = reinterpret_cast<__half2*>(&val_f);
    val_h[0] = __float22half2_rn(inp1_h_f_0);
    val_h[1] = __float22half2_rn(inp1_h_f_1);

    float2* out_4 = reinterpret_cast<float2*>(out);
    out_4[row * row_stride + id] = val_f;
}
template <>
void launch_fused_add4<float>(float* out,
                              const float* inp1,
                              const float* inp2,
                              const float* inp3,
                              const float* inp4,
                              int batch_size,
                              int seq_length,
                              int hidden_size,
                              cudaStream_t& stream)
{
    dim3 grid_dim(batch_size * seq_length);
    dim3 block_dim(hidden_size / 4);

    fused_add4_kernel<<<grid_dim, block_dim, 0, stream>>>(
        out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}

template <>
void launch_fused_add4<__half>(__half* out,
                               const __half* inp1,
                               const __half* inp2,
                               const __half* inp3,
                               const __half* inp4,
                               int batch_size,
                               int seq_length,
                               int hidden_size,
                               cudaStream_t& stream)
{
    dim3 grid_dim(batch_size * seq_length);
    dim3 block_dim(hidden_size / 4);

    fused_add4_kernel<<<grid_dim, block_dim, 0, stream>>>(
        out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}