Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
deepspeed
Commits
eadbbe09
Commit
eadbbe09
authored
Apr 25, 2021
by
401qingkong
Browse files
push rocm deepspeed v0.3.13
parent
ab5534fc
Changes
155
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
5555 additions
and
0 deletions
+5555
-0
csrc/includes/hip/ds_transformer_hip.h
csrc/includes/hip/ds_transformer_hip.h
+184
-0
csrc/includes/hip/feed_forward.h
csrc/includes/hip/feed_forward.h
+97
-0
csrc/includes/hip/gelu.h
csrc/includes/hip/gelu.h
+36
-0
csrc/includes/hip/gemm_test.h
csrc/includes/hip/gemm_test.h
+331
-0
csrc/includes/hip/gemm_test.h.bak
csrc/includes/hip/gemm_test.h.bak
+295
-0
csrc/includes/hip/general_kernels.h
csrc/includes/hip/general_kernels.h
+47
-0
csrc/includes/hip/normalize_layer.h
csrc/includes/hip/normalize_layer.h
+202
-0
csrc/includes/hip/softmax.h
csrc/includes/hip/softmax.h
+60
-0
csrc/includes/hip/strided_batch_gemm.h
csrc/includes/hip/strided_batch_gemm.h
+183
-0
csrc/includes/hip/strided_batch_gemm.h.bak
csrc/includes/hip/strided_batch_gemm.h.bak
+179
-0
csrc/includes/hip/type_shim.h
csrc/includes/hip/type_shim.h
+110
-0
csrc/lamb/hip/fused_lamb_hip.cpp
csrc/lamb/hip/fused_lamb_hip.cpp
+109
-0
csrc/lamb/hip/fused_lamb_hip_kernel.hip
csrc/lamb/hip/fused_lamb_hip_kernel.hip
+514
-0
csrc/sparse_attention/hip/utils.cpp
csrc/sparse_attention/hip/utils.cpp
+121
-0
csrc/transformer/hip/cublas_wrappers.hip
csrc/transformer/hip/cublas_wrappers.hip
+223
-0
csrc/transformer/hip/cublas_wrappers.hip.bak
csrc/transformer/hip/cublas_wrappers.hip.bak
+199
-0
csrc/transformer/hip/dropout_kernels.hip
csrc/transformer/hip/dropout_kernels.hip
+869
-0
csrc/transformer/hip/ds_transformer_hip.cpp
csrc/transformer/hip/ds_transformer_hip.cpp
+1048
-0
csrc/transformer/hip/gelu_kernels.hip
csrc/transformer/hip/gelu_kernels.hip
+336
-0
csrc/transformer/hip/general_kernels.hip
csrc/transformer/hip/general_kernels.hip
+412
-0
No files found.
csrc/includes/hip/ds_transformer_hip.h
0 → 100644
View file @
eadbbe09
#pragma once
#include <hip/hip_runtime_api.h>
#include <hiprand.h>
#include <memory>
#include <vector>
#include "rocblas.h"
#include "hip/hip_runtime.h"
#include "hip/dropout.h"
#include "hip/feed_forward.h"
#include "hip/gelu.h"
#include "hip/general_kernels.h"
#include "hip/normalize_layer.h"
#include "hip/softmax.h"
#include "hip/strided_batch_gemm.h"
struct
BertGemmAlgos
{
int
m_gemm_qkv_algo
;
int
m_gemm_inter_algo
;
int
m_gemm_output_algo
;
int
m_gemm_batch1_algo
;
int
m_gemm_batch2_algo
;
BertGemmAlgos
()
:
m_gemm_qkv_algo
(
-
1
),
m_gemm_inter_algo
(
-
1
),
m_gemm_output_algo
(
-
1
),
m_gemm_batch1_algo
(
-
1
),
m_gemm_batch2_algo
(
-
1
)
{
}
};
template
<
typename
T
>
class
BertTransformerLayer
{
public:
BertTransformerLayer
(
int
layer_id
,
int
batch_size
,
int
hidden_size
,
int
num_heads
,
int
intermediate_size
,
int
seq_length
,
float
attn_dropout_ratio
,
float
hidden_output_dropout_ratio
,
float
layer_norm_eps
,
bool
pre_or_postLayerNorm
,
const
std
::
vector
<
std
::
array
<
int
,
3
>>&
gemm_algos
,
bool
attn_dropout_checkpoint
,
bool
normalize_invertible
,
bool
gelu_checkpoint
,
bool
stochastic_mode
);
virtual
~
BertTransformerLayer
();
void
Forward
(
int
bsz
,
const
T
*
input_ptr
,
const
T
*
input_mask_ptr
,
const
T
*
attn_qkvw_ptr
,
const
T
*
attn_qkvb_ptr
,
const
T
*
attn_ow_ptr
,
const
T
*
attn_ob_ptr
,
const
T
*
attn_nw_ptr
,
const
T
*
attn_nb_ptr
,
const
T
*
inter_w_ptr
,
const
T
*
inter_b_ptr
,
const
T
*
output_w_ptr
,
const
T
*
output_b_ptr
,
const
T
*
norm_w_ptr
,
const
T
*
norm_b_ptr
,
T
*
out_ptr
,
T
*
inp_norm_ptr
,
T
*
q_tf_ptr
,
T
*
k_tf_ptr
,
T
*
v_tf_ptr
,
T
*
softmax_output_ptr
,
T
*
ctx_bufB_ptr
,
T
*
attn_o_inp_ptr
,
T
*
add_res_ptr
,
T
*
ff1_inp_ptr
,
T
*
gelu_inp_ptr
,
T
*
ff2_inp_ptr
);
void
Backward
(
int
bsz
,
const
T
*
grad_output_ptr
,
const
T
*
input_ptr
,
const
T
*
output_ptr
,
const
T
*
inp_norm_ptr
,
const
T
*
q_tf_ptr
,
const
T
*
k_tf_ptr
,
const
T
*
v_tf_ptr
,
const
T
*
softmax_output_ptr
,
const
T
*
ctx_bufB_ptr
,
const
T
*
attn_o_inp_ptr
,
const
T
*
add_res_ptr
,
const
T
*
ff1_inp_ptr
,
const
T
*
gelu_inp_ptr
,
const
T
*
ff2_inp_ptr
,
const
T
*
input_mask_ptr
,
const
T
*
attn_qkvw_ptr
,
const
T
*
attn_ow_ptr
,
const
T
*
attn_nw_ptr
,
const
T
*
attn_nb_ptr
,
const
T
*
inter_w_ptr
,
const
T
*
inter_b_ptr
,
const
T
*
output_w_ptr
,
const
T
*
norm_w_ptr
,
const
T
*
norm_b_ptr
,
T
*
grad_input_ptr
,
T
*
grad_attn_qkvw_ptr
,
T
*
grad_attn_qkvb_ptr
,
T
*
grad_attn_ow_ptr
,
T
*
grad_attn_ob_ptr
,
T
*
grad_attn_nw_ptr
,
T
*
grad_attn_nb_ptr
,
T
*
grad_inter_w_ptr
,
T
*
grad_inter_b_ptr
,
T
*
grad_output_w_ptr
,
T
*
grad_output_b_ptr
,
T
*
grad_norm_w_ptr
,
T
*
grad_norm_b_ptr
);
void
SetIntermediateBuffers
(
uint8_t
*
attn_prob_dropout_mask_ptr
,
uint8_t
*
attn_output_dropout_mask_ptr
,
uint8_t
*
layer_output_dropout_mask_ptr
,
T
*
layer_norm_var
,
T
*
layer_norm_mean
,
T
*
attn_layer_norm_var
,
T
*
attn_layer_norm_mean
);
inline
int
GetBatchSize
()
const
{
return
_batch_size
;
}
inline
int
GetNumHeads
()
const
{
return
_heads
;
}
inline
int
GetSeqLength
()
const
{
return
_seq_length
;
}
inline
int
GetIntermediateSize
()
const
{
return
_intermediate_size
;
}
void
SetSeqLength
(
int
seq_len
);
inline
int
GetHiddenSize
()
const
{
return
_hidden_size
;
}
void
SetTrainingMode
(
bool
training
);
inline
bool
IsTrainingMode
()
const
{
return
_training
;
}
inline
bool
GeluCheckpoint
()
const
{
return
_gelu_checkpoint
;
}
private:
void
Initialize
();
size_t
getWorkspaceSize
(
int
maxBatchSize
)
const
;
// Params
int
_layer_id
;
int
_batch_size
;
int
_hidden_size
;
int
_heads
;
int
_size_per_head
;
int
_intermediate_size
;
int
_seq_length
;
bool
_pre_or_postLayerNorm
;
rocblas_handle
_cublasHandle
;
hipStream_t
_stream
;
// layers
FeedForward
<
T
>
_qkv_linear
;
FeedForward
<
T
>
_attn_out_linear
;
Normalize_Layer
<
T
>
_attn_layer_norm
;
Normalize_Layer
<
T
>
_layer_norm
;
Normalize_Layer
<
T
>*
_last_normalize
;
FeedForward
<
T
>
_ff1
,
_ff2
;
Softmax
<
T
>
_softmax
;
Gelu
<
T
>
_gelu
;
Dropout
<
T
>
_attn_prob_dropout
;
Dropout
<
T
>
_attn_output_dropout
;
Dropout
<
T
>
_layer_output_dropout
;
StridedBatchGemm
<
T
>
_attn_scores
;
StridedBatchGemm
<
T
>
_attn_context
;
bool
_training
;
// Memory saving flags
bool
_attn_dropout_checkpoint
;
bool
_normalize_invertible
;
bool
_gelu_checkpoint
;
// High Performace flags
bool
_stochastic_mode
;
};
csrc/includes/hip/feed_forward.h
0 → 100644
View file @
eadbbe09
#ifndef __FEEDFORWARD_H__
#define __FEEDFORWARD_H__
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include "hip/custom_hip_layers.h"
template
<
typename
T
>
class
FeedForward
{
public:
struct
Config
{
int
batchSize
,
outputSize
;
int
inputSize
;
std
::
array
<
int
,
3
>
gemm_algos
;
Config
(
int
batch
,
int
outputs
,
int
inputs
,
const
std
::
array
<
int
,
3
>&
algos
)
:
batchSize
(
batch
),
outputSize
(
outputs
),
inputSize
(
inputs
),
gemm_algos
(
algos
)
{
}
};
FeedForward
(
Config
config
)
:
config_
(
config
)
{}
~
FeedForward
()
{}
void
Forward
(
int
bsz
,
const
T
*
input_ptr
,
const
T
*
weights
,
T
*
out
,
rocblas_handle
&
_cublasHandle
)
{
float
alpha
=
T
(
1.
);
float
beta
=
T
(
0.
);
cublas_gemm_ex
(
_cublasHandle
,
rocblas_operation_transpose
,
rocblas_operation_none
,
config_
.
outputSize
,
bsz
,
config_
.
inputSize
,
&
alpha
,
&
beta
,
weights
,
input_ptr
,
out
,
//cublasGemmAlgo_t(config_.gemm_algos[0]));
rocblas_gemm_algo
(
config_
.
gemm_algos
[
0
]));
}
void
Backward
(
int
bsz
,
const
T
*
out_grad
,
const
T
*
input_ptr
,
const
T
*
weights
,
T
*
weights_grad
,
T
*
bias_grad
,
rocblas_handle
&
_cublasHandle
,
hipStream_t
&
stream
,
T
*
inp_grad_out
=
nullptr
,
T
*
out_grad_trans_out
=
nullptr
)
{
float
alpha
=
(
T
)
1.0
,
beta
=
(
T
)
0.0
;
cublas_gemm_ex
(
_cublasHandle
,
rocblas_operation_none
,
rocblas_operation_transpose
,
config_
.
inputSize
,
config_
.
outputSize
,
bsz
,
&
alpha
,
&
beta
,
input_ptr
,
out_grad
,
weights_grad
,
//cublasGemmAlgo_t(config_.gemm_algos[1]));
rocblas_gemm_algo
(
config_
.
gemm_algos
[
1
]));
cublas_gemm_ex
(
_cublasHandle
,
rocblas_operation_none
,
rocblas_operation_none
,
config_
.
inputSize
,
bsz
,
config_
.
outputSize
,
&
alpha
,
&
beta
,
weights
,
out_grad
,
inp_grad_out
,
//cublasGemmAlgo_t(config_.gemm_algos[2]));
rocblas_gemm_algo
(
config_
.
gemm_algos
[
2
]));
launch_fuse_transpose_bias_kernel
<
T
>
(
out_grad
,
bias_grad
,
bsz
,
config_
.
outputSize
,
stream
);
}
private:
Config
config_
;
};
#endif
csrc/includes/hip/gelu.h
0 → 100644
View file @
eadbbe09
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include "hip/custom_hip_layers.h"
template
<
typename
T
>
class
Gelu
{
public:
struct
Config
{
uint32_t
intermediate_size
;
Config
(
uint32_t
inter_size
)
:
intermediate_size
(
inter_size
)
{}
};
Gelu
(
const
Config
&
config
)
:
_config
(
config
)
{}
virtual
~
Gelu
()
{}
void
ForwardWithBiasAdd
(
int
bsz
,
const
T
*
input_buf
,
const
T
*
bias
,
T
*
output
,
hipStream_t
stream
)
{
launch_bias_gelu
<
T
>
(
input_buf
,
bias
,
output
,
_config
.
intermediate_size
,
bsz
,
stream
);
}
void
Backward
(
int
bsz
,
T
*
d_output
,
const
T
*
input_buf
,
const
T
*
bias
,
hipStream_t
stream
)
{
launch_d_gelu
<
T
>
(
d_output
,
input_buf
,
bias
,
_config
.
intermediate_size
,
bsz
,
stream
);
}
private:
Config
_config
;
};
csrc/includes/hip/gemm_test.h
0 → 100644
View file @
eadbbe09
#pragma once
#include <hip/hip_fp16.h>
#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif
#include <array>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <limits>
#include <memory>
#include "StopWatch.h"
#include "cublas_wrappers.h"
template
<
typename
T
>
void
check
(
T
result
,
char
const
*
const
func
,
const
char
*
const
file
,
int
const
line
)
{
if
(
result
)
{
std
::
cout
<<
(
std
::
string
(
"CUDA runtime error: "
)
+
+
file
+
":"
+
std
::
to_string
(
line
)
+
"
\n
"
);
}
}
#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)
template
<
typename
T
>
class
GemmTest
{
public:
GemmTest
(
int
m
,
int
n
,
int
k
,
rocblas_operation
ta
,
rocblas_operation
tb
,
rocblas_handle
h
)
:
M
(
m
),
N
(
n
),
K
(
k
),
transa
(
ta
),
transb
(
tb
),
handle
(
h
)
{
check_cuda_error
(
hipMalloc
((
void
**
)
&
A
,
sizeof
(
T
)
*
M
*
K
));
check_cuda_error
(
hipMalloc
((
void
**
)
&
B
,
sizeof
(
T
)
*
K
*
N
));
check_cuda_error
(
hipMalloc
((
void
**
)
&
C
,
sizeof
(
T
)
*
M
*
N
));
}
~
GemmTest
()
{
check_cuda_error
(
hipFree
(
A
));
check_cuda_error
(
hipFree
(
B
));
check_cuda_error
(
hipFree
(
C
));
}
std
::
array
<
int
,
3
>
TestAlgo
(
int
loops
)
{
float
alpha
=
(
T
)
1.0
f
;
float
beta
=
(
T
)
0.0
f
;
int
algo_fw
=
Run
(
loops
,
[
=
](
int
algo
)
{
cublas_gemm_ex
(
handle
,
rocblas_operation_transpose
,
rocblas_operation_none
,
N
,
M
,
K
,
&
alpha
,
&
beta
,
B
,
A
,
C
,
#ifdef __HIP_PLATFORM_HCC__
static_cast
<
rocblas_gemm_algo
>
(
algo
));
#else
static_cast
<
cublasGemmAlgo_t
>
(
algo
));
#endif
});
int
algo_bw1
=
Run
(
loops
,
[
=
](
int
algo
)
{
cublas_gemm_ex
(
handle
,
rocblas_operation_none
,
rocblas_operation_transpose
,
K
,
N
,
M
,
&
alpha
,
&
beta
,
A
,
C
,
B
,
#ifdef __HIP_PLATFORM_HCC__
static_cast
<
rocblas_gemm_algo
>
(
algo
));
#else
static_cast
<
cublasGemmAlgo_t
>
(
algo
));
#endif
});
int
algo_bw2
=
Run
(
loops
,
[
=
](
int
algo
)
{
cublas_gemm_ex
(
handle
,
rocblas_operation_none
,
rocblas_operation_none
,
K
,
M
,
N
,
&
alpha
,
&
beta
,
B
,
C
,
A
,
#ifdef __HIP_PLATFORM_HCC__
static_cast
<
rocblas_gemm_algo
>
(
algo
));
#else
static_cast
<
cublasGemmAlgo_t
>
(
algo
));
#endif
});
return
std
::
array
<
int
,
3
>
({
algo_fw
,
algo_bw1
,
algo_bw2
});
}
template
<
typename
Func
>
int
Run
(
int
loops
,
Func
f
)
{
//float fast_latency = std::numeric_limits<float>::max();
float
fast_latency
=
(
std
::
numeric_limits
<
float
>::
max
)();
int
fast_algo
=
0
;
#ifdef __HIP_PLATFORM_HCC__
for
(
int
algo
=
(
int
)
rocblas_gemm_algo_standard
;
algo
<=
(
int
)
rocblas_gemm_algo_standard
;
#else
for
(
int
algo
=
(
int
)
CUBLAS_GEMM_DEFAULT_TENSOR_OP
;
algo
<=
(
int
)
CUBLAS_GEMM_ALGO15_TENSOR_OP
;
#endif
algo
++
)
{
int
warm_up
=
5
;
for
(
int
i
=
0
;
i
<
warm_up
;
++
i
)
f
(
algo
);
hipDeviceSynchronize
();
Stopwatch
timer
;
timer
.
Restart
();
for
(
int
i
=
0
;
i
<
loops
;
++
i
)
f
(
algo
);
hipDeviceSynchronize
();
timer
.
Stop
();
float
avg_latency
=
(
float
)
timer
.
GetTimeInSeconds
()
*
1000
/
loops
;
printf
(
"algo-%d: %.3fms
\n
"
,
algo
,
avg_latency
);
if
(
avg_latency
<
fast_latency
)
{
fast_latency
=
avg_latency
;
fast_algo
=
algo
;
}
}
printf
(
"fast_algo %d: %.3f ms
\n
"
,
fast_algo
,
fast_latency
);
return
fast_algo
;
}
private:
int
M
,
N
,
K
;
rocblas_handle
handle
;
rocblas_operation
transa
,
transb
;
T
*
A
,
*
B
,
*
C
;
};
template
<
typename
T
>
class
StridedGemmTest
{
public:
StridedGemmTest
(
int
b
,
int
m
,
int
n
,
int
k
,
rocblas_operation
ta
,
rocblas_operation
tb
,
rocblas_handle
h
)
:
bsz
(
b
),
M
(
m
),
N
(
n
),
K
(
k
),
transa
(
ta
),
transb
(
tb
),
handle
(
h
)
{
check_cuda_error
(
hipMalloc
((
void
**
)
&
A
,
sizeof
(
T
)
*
M
*
K
*
bsz
));
check_cuda_error
(
hipMalloc
((
void
**
)
&
B
,
sizeof
(
T
)
*
K
*
N
*
bsz
));
check_cuda_error
(
hipMalloc
((
void
**
)
&
C
,
sizeof
(
T
)
*
M
*
N
*
bsz
));
}
~
StridedGemmTest
()
{
check_cuda_error
(
hipFree
(
A
));
check_cuda_error
(
hipFree
(
B
));
check_cuda_error
(
hipFree
(
C
));
}
std
::
array
<
int
,
3
>
TestAlgo
(
int
loops
)
{
float
alpha
=
(
T
)
1.0
f
;
float
beta
=
(
T
)
0.0
f
;
int
algo_fw
=
Run
(
loops
,
[
=
](
int
algo
)
{
int
stride_a
=
M
*
K
;
int
stride_b
=
N
*
K
;
int
stride_c
=
M
*
N
;
cublas_strided_batched_gemm
(
handle
,
M
,
N
,
K
,
&
alpha
,
&
beta
,
A
,
B
,
C
,
transa
,
transb
,
stride_a
,
stride_b
,
stride_c
,
bsz
,
#ifdef __HIP_PLATFORM_HCC__
static_cast
<
rocblas_gemm_algo
>
(
algo
));
#else
static_cast
<
cublasGemmAlgo_t
>
(
algo
));
#endif
});
int
algo_bw1
=
Run
(
loops
,
[
=
](
int
algo
)
{
int
mb
=
(
transa
==
rocblas_operation_transpose
?
K
:
M
);
int
kb
=
(
transa
==
rocblas_operation_transpose
?
M
:
K
);
int
stride_a
=
mb
*
N
;
int
stride_b
=
N
*
kb
;
int
stride_c
=
M
*
K
;
// B need to transpose.
rocblas_operation
op_b
=
(
transb
==
rocblas_operation_transpose
?
rocblas_operation_none
:
rocblas_operation_transpose
);
// Calculate d_A.
cublas_strided_batched_gemm
(
handle
,
mb
,
kb
,
N
,
&
alpha
,
&
beta
,
(
transa
==
rocblas_operation_transpose
?
B
:
C
),
(
transa
==
rocblas_operation_transpose
?
C
:
B
),
A
,
rocblas_operation_none
,
op_b
,
stride_a
,
stride_b
,
stride_c
,
bsz
,
#ifdef __HIP_PLATFORM_HCC__
static_cast
<
rocblas_gemm_algo
>
(
algo
));
#else
static_cast
<
cublasGemmAlgo_t
>
(
algo
));
#endif
});
int
algo_bw2
=
Run
(
loops
,
[
=
](
int
algo
)
{
// A need to transpose.
rocblas_operation
op_a
=
(
transa
==
rocblas_operation_transpose
?
rocblas_operation_none
:
rocblas_operation_transpose
);
int
stride_a
=
M
*
K
;
int
stride_b
=
M
*
N
;
int
stride_c
=
N
*
K
;
// Calculate d_B.
cublas_strided_batched_gemm
(
handle
,
K
,
N
,
M
,
&
alpha
,
&
beta
,
A
,
C
,
B
,
op_a
,
rocblas_operation_none
,
stride_a
,
stride_b
,
stride_c
,
bsz
,
#ifdef __HIP_PLATFORM_HCC__
static_cast
<
rocblas_gemm_algo
>
(
algo
));
#else
static_cast
<
cublasGemmAlgo_t
>
(
algo
));
#endif
});
return
std
::
array
<
int
,
3
>
({
algo_fw
,
algo_bw1
,
algo_bw2
});
}
template
<
typename
Func
>
int
Run
(
int
loops
,
Func
f
)
{
//float fast_latency = std::numeric_limits<float>::max();
float
fast_latency
=
(
std
::
numeric_limits
<
float
>::
max
)();
int
fast_algo
=
0
;
#ifdef __HIP_PLATFORM_HCC__
for
(
int
algo
=
(
int
)
rocblas_gemm_algo_standard
;
algo
<=
(
int
)
rocblas_gemm_algo_standard
;
#else
for
(
int
algo
=
(
int
)
CUBLAS_GEMM_DEFAULT_TENSOR_OP
;
algo
<=
(
int
)
CUBLAS_GEMM_ALGO15_TENSOR_OP
;
#endif
algo
++
)
{
int
warm_up
=
5
;
for
(
int
i
=
0
;
i
<
warm_up
;
++
i
)
f
(
algo
);
hipDeviceSynchronize
();
Stopwatch
timer
;
timer
.
Restart
();
for
(
int
i
=
0
;
i
<
loops
;
++
i
)
f
(
algo
);
hipDeviceSynchronize
();
timer
.
Stop
();
float
avg_latency
=
(
float
)
timer
.
GetTimeInSeconds
()
*
1000
/
loops
;
printf
(
"algo-%d: %.3fms
\n
"
,
algo
,
avg_latency
);
if
(
avg_latency
<
fast_latency
)
{
fast_latency
=
avg_latency
;
fast_algo
=
algo
;
}
}
printf
(
"fast_algo %d: %.3f ms
\n
"
,
fast_algo
,
fast_latency
);
return
fast_algo
;
}
private:
int
bsz
,
M
,
N
,
K
;
rocblas_handle
handle
;
rocblas_operation
transa
,
transb
;
T
*
A
,
*
B
,
*
C
;
};
csrc/includes/hip/gemm_test.h.bak
0 → 100644
View file @
eadbbe09
#pragma once
#include <hip/hip_fp16.h>
#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif
#include <array>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <limits>
#include <memory>
#include "StopWatch.h"
#include "hip/cublas_wrappers.h"
template <typename T>
void check(T result, char const* const func, const char* const file, int const line)
{
if (result) {
std::cout << (std::string("CUDA runtime error: ") + +file + ":" + std::to_string(line) +
" \n");
}
}
#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)
template <typename T>
class GemmTest {
public:
GemmTest(int m, int n, int k, rocblas_operation ta, rocblas_operation tb, rocblas_handle h)
: M(m), N(n), K(k), transa(ta), transb(tb), handle(h)
{
check_cuda_error(hipMalloc((void**)&A, sizeof(T) * M * K));
check_cuda_error(hipMalloc((void**)&B, sizeof(T) * K * N));
check_cuda_error(hipMalloc((void**)&C, sizeof(T) * M * N));
}
~GemmTest()
{
check_cuda_error(hipFree(A));
check_cuda_error(hipFree(B));
check_cuda_error(hipFree(C));
}
std::array<int, 3> TestAlgo(int loops)
{
float alpha = (T)1.0f;
float beta = (T)0.0f;
int algo_fw = Run(loops, [=](int algo) {
cublas_gemm_ex(handle,
rocblas_operation_transpose,
rocblas_operation_none,
N,
M,
K,
&alpha,
&beta,
B,
A,
C,
static_cast<cublasGemmAlgo_t>(algo));
});
int algo_bw1 = Run(loops, [=](int algo) {
cublas_gemm_ex(handle,
rocblas_operation_none,
rocblas_operation_transpose,
K,
N,
M,
&alpha,
&beta,
A,
C,
B,
static_cast<cublasGemmAlgo_t>(algo));
});
int algo_bw2 = Run(loops, [=](int algo) {
cublas_gemm_ex(handle,
rocblas_operation_none,
rocblas_operation_none,
K,
M,
N,
&alpha,
&beta,
B,
C,
A,
static_cast<cublasGemmAlgo_t>(algo));
});
return std::array<int, 3>({algo_fw, algo_bw1, algo_bw2});
}
template <typename Func>
int Run(int loops, Func f)
{
float fast_latency = (std::numeric_limits<float>::max)();
int fast_algo = 0;
for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
algo++) {
int warm_up = 5;
for (int i = 0; i < warm_up; ++i) f(algo);
hipDeviceSynchronize();
Stopwatch timer;
timer.Restart();
for (int i = 0; i < loops; ++i) f(algo);
hipDeviceSynchronize();
timer.Stop();
float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops;
printf("algo-%d: %.3fms\n", algo, avg_latency);
if (avg_latency < fast_latency) {
fast_latency = avg_latency;
fast_algo = algo;
}
}
printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency);
return fast_algo;
}
private:
int M, N, K;
rocblas_handle handle;
rocblas_operation transa, transb;
T *A, *B, *C;
};
template <typename T>
class StridedGemmTest {
public:
StridedGemmTest(int b,
int m,
int n,
int k,
rocblas_operation ta,
rocblas_operation tb,
rocblas_handle h)
: bsz(b), M(m), N(n), K(k), transa(ta), transb(tb), handle(h)
{
check_cuda_error(hipMalloc((void**)&A, sizeof(T) * M * K * bsz));
check_cuda_error(hipMalloc((void**)&B, sizeof(T) * K * N * bsz));
check_cuda_error(hipMalloc((void**)&C, sizeof(T) * M * N * bsz));
}
~StridedGemmTest()
{
check_cuda_error(hipFree(A));
check_cuda_error(hipFree(B));
check_cuda_error(hipFree(C));
}
std::array<int, 3> TestAlgo(int loops)
{
float alpha = (T)1.0f;
float beta = (T)0.0f;
int algo_fw = Run(loops, [=](int algo) {
int stride_a = M * K;
int stride_b = N * K;
int stride_c = M * N;
cublas_strided_batched_gemm(handle,
M,
N,
K,
&alpha,
&beta,
A,
B,
C,
transa,
transb,
stride_a,
stride_b,
stride_c,
bsz,
static_cast<cublasGemmAlgo_t>(algo));
});
int algo_bw1 = Run(loops, [=](int algo) {
int mb = (transa == rocblas_operation_transpose ? K : M);
int kb = (transa == rocblas_operation_transpose ? M : K);
int stride_a = mb * N;
int stride_b = N * kb;
int stride_c = M * K;
// B need to transpose.
rocblas_operation op_b = (transb == rocblas_operation_transpose ? rocblas_operation_none : rocblas_operation_transpose);
// Calculate d_A.
cublas_strided_batched_gemm(handle,
mb,
kb,
N,
&alpha,
&beta,
(transa == rocblas_operation_transpose ? B : C),
(transa == rocblas_operation_transpose ? C : B),
A,
rocblas_operation_none,
op_b,
stride_a,
stride_b,
stride_c,
bsz,
static_cast<cublasGemmAlgo_t>(algo));
});
int algo_bw2 = Run(loops, [=](int algo) {
// A need to transpose.
rocblas_operation op_a = (transa == rocblas_operation_transpose ? rocblas_operation_none : rocblas_operation_transpose);
int stride_a = M * K;
int stride_b = M * N;
int stride_c = N * K;
// Calculate d_B.
cublas_strided_batched_gemm(handle,
K,
N,
M,
&alpha,
&beta,
A,
C,
B,
op_a,
rocblas_operation_none,
stride_a,
stride_b,
stride_c,
bsz,
static_cast<cublasGemmAlgo_t>(algo));
});
return std::array<int, 3>({algo_fw, algo_bw1, algo_bw2});
}
template <typename Func>
int Run(int loops, Func f)
{
float fast_latency = (std::numeric_limits<float>::max)();
int fast_algo = 0;
for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
algo++) {
int warm_up = 5;
for (int i = 0; i < warm_up; ++i) f(algo);
hipDeviceSynchronize();
Stopwatch timer;
timer.Restart();
for (int i = 0; i < loops; ++i) f(algo);
hipDeviceSynchronize();
timer.Stop();
float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops;
printf("algo-%d: %.3fms\n", algo, avg_latency);
if (avg_latency < fast_latency) {
fast_latency = avg_latency;
fast_algo = algo;
}
}
printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency);
return fast_algo;
}
private:
int bsz, M, N, K;
rocblas_handle handle;
rocblas_operation transa, transb;
T *A, *B, *C;
};
csrc/includes/hip/general_kernels.h
0 → 100644
View file @
eadbbe09
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include <stdlib.h>
#include <hip/hip_cooperative_groups.h>
#include <hiprand_kernel.h>
#include "hip/context.h"
#include "hip/cublas_wrappers.h"
#define THREADS 256
#define TILE_DIM 32
#define minus_infinity -1 * std::numeric_limits<float>::infinity()
#define FINAL_MASK 0xffffffff
template
<
typename
T
>
void
launch_fused_add2
(
T
*
out
,
const
T
*
inp1
,
const
T
*
inp2
,
int
batch_size
,
int
seq_length
,
int
hidden_size
,
hipStream_t
&
stream
);
template
<
typename
T
>
void
launch_fused_add4
(
T
*
out
,
const
T
*
inp1
,
const
T
*
inp2
,
const
T
*
inp3
,
const
T
*
inp4
,
int
batch_size
,
int
seq_length
,
int
hidden_size
,
hipStream_t
&
stream
);
template
<
typename
T
>
void
launch_fused_add3
(
T
*
out
,
const
T
*
inp1
,
const
T
*
inp2
,
const
T
*
inp3
,
int
batch_size
,
int
seq_length
,
int
hidden_size
,
hipStream_t
&
stream
);
csrc/includes/hip/normalize_layer.h
0 → 100644
View file @
eadbbe09
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include <fstream>
#include "hip/custom_hip_layers.h"
using
namespace
std
;
template
<
typename
T
>
class
Normalize_Layer
{
public:
struct
Config
{
uint32_t
batchSize
;
uint32_t
seqLength
;
uint32_t
hiddenDim
;
float
epsilon
;
bool
training
;
bool
useMean
;
Config
(
uint32_t
batch
,
uint32_t
seq
,
uint32_t
h
,
float
epsilon
=
1e-12
,
bool
training
=
true
,
bool
useMean
=
true
)
:
batchSize
(
batch
),
seqLength
(
seq
),
hiddenDim
(
h
),
epsilon
(
epsilon
),
training
(
training
),
useMean
(
useMean
)
{
}
};
Normalize_Layer
(
Config
config
)
:
config_
(
config
),
vars
(
nullptr
),
means
(
nullptr
),
vals_hat
(
nullptr
)
{
}
~
Normalize_Layer
()
{}
void
ForwardCheckpoint
(
int
bsz
,
// batch * seq
T
*
vals
,
const
T
*
residual
,
const
T
*
gamma
,
const
T
*
betta
,
hipStream_t
&
stream
,
bool
preLayerNorm
=
false
)
{
launch_bias_residual_layer_norm
(
vals
,
residual
,
gamma
,
betta
,
config_
.
epsilon
,
bsz
,
config_
.
hiddenDim
,
stream
,
preLayerNorm
,
config_
.
training
,
vars
,
means
);
}
void
Forward
(
int
bsz
,
T
*
vals
,
const
T
*
residual
,
const
T
*
gamma
,
const
T
*
betta
,
hipStream_t
&
stream
,
bool
preLayerNorm
=
false
)
{
launch_bias_residual_layer_norm
(
vals
,
residual
,
gamma
,
betta
,
config_
.
epsilon
,
bsz
,
config_
.
hiddenDim
,
stream
,
preLayerNorm
,
config_
.
training
,
vars
);
}
void
Backward
(
int
bsz
,
const
T
*
out_grad
,
const
T
*
gamma
,
T
*
gamma_grad
,
T
*
betta_grad
,
hipStream_t
stream
[
2
],
T
*
inp_grad_out
,
const
T
*
norm_in
=
nullptr
)
{
launch_layerNorm_backward
(
out_grad
,
norm_in
,
vars
,
means
,
gamma
,
gamma_grad
,
betta_grad
,
inp_grad_out
,
bsz
,
config_
.
hiddenDim
,
stream
);
}
void
Backward
(
int
bsz
,
const
T
*
out_grad
,
const
T
*
gamma
,
const
T
*
betta
,
T
*
gamma_grad
,
T
*
betta_grad
,
hipStream_t
stream
[
2
],
T
*
inp_grad_out
,
const
T
*
norm_out
)
{
launch_layerNorm_backward
(
out_grad
,
norm_out
,
vars
,
gamma
,
gamma_grad
,
betta_grad
,
inp_grad_out
,
bsz
,
config_
.
hiddenDim
,
stream
,
!
config_
.
useMean
,
betta
);
}
void
BackwardFusedAdd
(
int
bsz
,
const
T
*
out_grad1
,
const
T
*
out_grad2
,
const
T
*
gamma
,
T
*
gamma_grad
,
T
*
betta_grad
,
hipStream_t
stream
[
2
],
T
*
inp_grad_out
,
const
T
*
norm_in
=
nullptr
)
{
launch_layerNorm_backward_fused_add
(
out_grad1
,
out_grad2
,
norm_in
,
vars
,
means
,
gamma
,
gamma_grad
,
betta_grad
,
inp_grad_out
,
bsz
,
config_
.
hiddenDim
,
stream
);
}
void
BackwardFusedAdd
(
int
bsz
,
const
T
*
out_grad1
,
const
T
*
out_grad2
,
const
T
*
gamma
,
const
T
*
betta
,
T
*
gamma_grad
,
T
*
betta_grad
,
hipStream_t
stream
[
2
],
T
*
inp_grad_out
,
const
T
*
norm_out
)
{
launch_layerNorm_backward_fused_add
(
out_grad1
,
out_grad2
,
norm_out
,
vars
,
gamma
,
gamma_grad
,
betta_grad
,
inp_grad_out
,
bsz
,
config_
.
hiddenDim
,
stream
,
!
config_
.
useMean
,
betta
);
}
inline
bool
UseMean
()
const
{
return
config_
.
useMean
;
}
inline
void
SetVar
(
T
*
variance
)
{
if
(
!
variance
)
{
throw
std
::
runtime_error
(
"Normalize variance is null."
);
}
vars
=
variance
;
}
inline
void
SetMean
(
T
*
mean
)
{
if
(
!
mean
)
{
throw
std
::
runtime_error
(
"Normalize mean is null."
);
}
means
=
mean
;
}
private:
Config
config_
;
T
*
vars
;
T
*
means
;
T
*
vals_hat
;
};
csrc/includes/hip/softmax.h
0 → 100644
View file @
eadbbe09
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include "hip/custom_hip_layers.h"
#include <fstream>
using
namespace
std
;
template
<
typename
T
>
class
Softmax
{
public:
struct
Config
{
size_t
batchSize
;
size_t
heads
;
size_t
seq_length
;
size_t
prob_depth
;
float
temprature
;
bool
mem_alloc
;
Config
(
size_t
batch
,
size_t
h
,
size_t
seq
,
int
prob_size
=
0
,
bool
mem_alloc
=
false
)
:
batchSize
(
batch
),
heads
(
h
),
seq_length
(
seq
),
prob_depth
(
prob_size
),
temprature
(
1.0
),
mem_alloc
(
mem_alloc
)
{
}
};
Softmax
(
Config
config
)
:
config_
(
config
)
{}
~
Softmax
()
{}
void
Forward
(
int
bsz
,
T
*
vals
,
const
T
*
attn_mask
,
hipStream_t
&
stream
)
{
launch_attn_softmax
<
T
>
(
vals
,
attn_mask
,
bsz
,
config_
.
heads
,
config_
.
seq_length
,
stream
);
}
void
Backward
(
int
bsz
,
T
*
out_grad
,
const
T
*
soft_out
,
hipStream_t
stream
)
{
launch_attn_softmax_backward_v2
<
T
>
(
out_grad
,
soft_out
,
bsz
,
config_
.
heads
,
config_
.
seq_length
,
stream
);
}
inline
size_t
GetProbDepth
()
const
{
return
config_
.
prob_depth
;
}
inline
size_t
GetBatchSize
()
const
{
return
config_
.
batchSize
;
}
inline
size_t
GetNumHeads
()
const
{
return
config_
.
heads
;
}
inline
size_t
GetSeqLength
()
const
{
return
config_
.
seq_length
;
}
inline
void
SetSeqLength
(
size_t
seq_len
)
{
config_
.
seq_length
=
seq_len
;
}
private:
Config
config_
;
};
csrc/includes/hip/strided_batch_gemm.h
0 → 100644
View file @
eadbbe09
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include "hip/context.h"
template
<
typename
T
>
class
StridedBatchGemm
{
public:
struct
Config
{
int
batch_size
;
int
m
;
int
n
;
int
k
;
float
alpha
;
float
beta
;
rocblas_operation
op_A
;
rocblas_operation
op_B
;
std
::
array
<
int
,
3
>
gemm_algos
;
Config
(
int
batch
,
int
mm
,
int
nn
,
int
kk
,
float
param_alpha
,
float
param_beta
,
rocblas_operation
opA
,
rocblas_operation
opB
,
const
std
::
array
<
int
,
3
>&
algos
)
:
batch_size
(
batch
),
m
(
mm
),
n
(
nn
),
k
(
kk
),
alpha
(
param_alpha
),
beta
(
param_beta
),
op_A
(
opA
),
op_B
(
opB
),
gemm_algos
(
algos
)
{
}
void
SetConfig
(
int
mm
,
int
nn
,
int
kk
)
{
m
=
mm
;
n
=
nn
;
k
=
kk
;
}
};
StridedBatchGemm
(
const
Config
&
config
)
:
_config
(
config
)
{}
virtual
~
StridedBatchGemm
()
{}
void
Forward
(
int
bsz
,
T
*
output
,
const
T
*
_buffer_a
,
const
T
*
_buffer_b
,
rocblas_handle
handle
)
{
int
stride_a
=
_config
.
m
*
_config
.
k
;
int
stride_b
=
_config
.
n
*
_config
.
k
;
int
stride_c
=
_config
.
m
*
_config
.
n
;
cublas_strided_batched_gemm
(
handle
,
//rocblas_sgemm_strided_batched(handle,
_config
.
m
,
_config
.
n
,
_config
.
k
,
&
_config
.
alpha
,
&
_config
.
beta
,
_buffer_a
,
_buffer_b
,
output
,
_config
.
op_A
,
_config
.
op_B
,
stride_a
,
stride_b
,
stride_c
,
bsz
,
rocblas_gemm_algo
(
_config
.
gemm_algos
[
0
]));
//rocblas_sgemm_strided_batched(handle,
}
void
ForwardPlusSave
(
T
*
output
,
const
T
*
_buffer_a
,
const
T
*
_buffer_b
,
rocblas_handle
handle
)
{
int
stride_a
=
_config
.
m
*
_config
.
k
;
int
stride_b
=
_config
.
n
*
_config
.
k
;
int
stride_c
=
_config
.
m
*
_config
.
n
;
cublas_strided_batched_gemm
(
handle
,
_config
.
m
,
_config
.
n
,
_config
.
k
,
&
_config
.
alpha
,
&
_config
.
beta
,
_buffer_a
,
_buffer_b
,
output
,
_config
.
op_A
,
_config
.
op_B
,
stride_a
,
stride_b
,
stride_c
,
_config
.
batch_size
,
//cublasGemmAlgo_t(_config.gemm_algos[0]));
rocblas_gemm_algo
(
_config
.
gemm_algos
[
0
]));
k_buf
=
_buffer_a
;
q_buf
=
_buffer_b
;
}
void
Backward
(
int
bsz
,
const
T
*
d_output
,
const
T
*
_buffer_a
,
const
T
*
_buffer_b
,
rocblas_handle
handle
,
T
*
inpGradA
=
nullptr
,
T
*
inpGradB
=
nullptr
)
{
int
mb
=
(
_config
.
op_A
==
rocblas_operation_transpose
?
_config
.
k
:
_config
.
m
);
int
kb
=
(
_config
.
op_A
==
rocblas_operation_transpose
?
_config
.
m
:
_config
.
k
);
int
stride_a
=
mb
*
_config
.
n
;
int
stride_b
=
_config
.
n
*
kb
;
int
stride_c
=
_config
.
m
*
_config
.
k
;
// B need to transpose.
rocblas_operation
op_b
=
(
_config
.
op_B
==
rocblas_operation_transpose
?
rocblas_operation_none
:
rocblas_operation_transpose
);
// Calculate d_A.
cublas_strided_batched_gemm
(
handle
,
//rocblas_sgemm_strided_batched(handle,
mb
,
kb
,
_config
.
n
,
&
_config
.
alpha
,
&
_config
.
beta
,
(
_config
.
op_A
==
rocblas_operation_transpose
?
_buffer_b
:
d_output
),
(
_config
.
op_A
==
rocblas_operation_transpose
?
d_output
:
_buffer_b
),
inpGradA
,
rocblas_operation_none
,
op_b
,
stride_a
,
stride_b
,
stride_c
,
bsz
,
//cublasGemmAlgo_t(_config.gemm_algos[1]));
rocblas_gemm_algo
(
_config
.
gemm_algos
[
1
]));
// A need to transpose.
rocblas_operation
op_a
=
(
_config
.
op_A
==
rocblas_operation_transpose
?
rocblas_operation_none
:
rocblas_operation_transpose
);
stride_a
=
_config
.
m
*
_config
.
k
;
stride_b
=
_config
.
m
*
_config
.
n
;
stride_c
=
_config
.
n
*
_config
.
k
;
// Calculate d_B.
cublas_strided_batched_gemm
(
handle
,
//rocblas_sgemm_strided_batched(handle,
_config
.
k
,
_config
.
n
,
_config
.
m
,
&
_config
.
alpha
,
&
_config
.
beta
,
_buffer_a
,
d_output
,
inpGradB
,
op_a
,
rocblas_operation_none
,
stride_a
,
stride_b
,
stride_c
,
bsz
,
//cublasGemmAlgo_t(_config.gemm_algos[2]));
rocblas_gemm_algo
(
_config
.
gemm_algos
[
2
]));
}
inline
int
GetN
()
const
{
return
_config
.
k
;
}
inline
const
T
*
GetBufferA
()
const
{
return
k_buf
;
}
inline
const
T
*
GetBufferB
()
const
{
return
q_buf
;
}
inline
void
SetConfig
(
int
m
,
int
n
,
int
k
)
{
_config
.
SetConfig
(
m
,
n
,
k
);
}
private:
Config
_config
;
const
T
*
q_buf
;
const
T
*
k_buf
;
};
csrc/includes/hip/strided_batch_gemm.h.bak
0 → 100644
View file @
eadbbe09
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include "hip/context.h"
template <typename T>
class StridedBatchGemm {
public:
struct Config {
int batch_size;
int m;
int n;
int k;
float alpha;
float beta;
rocblas_operation op_A;
rocblas_operation op_B;
std::array<int, 3> gemm_algos;
Config(int batch,
int mm,
int nn,
int kk,
float param_alpha,
float param_beta,
rocblas_operation opA,
rocblas_operation opB,
const std::array<int, 3>& algos)
: batch_size(batch),
m(mm),
n(nn),
k(kk),
alpha(param_alpha),
beta(param_beta),
op_A(opA),
op_B(opB),
gemm_algos(algos)
{
}
void SetConfig(int mm, int nn, int kk)
{
m = mm;
n = nn;
k = kk;
}
};
StridedBatchGemm(const Config& config) : _config(config) {}
virtual ~StridedBatchGemm() {}
void Forward(int bsz, T* output, const T* _buffer_a, const T* _buffer_b, rocblas_handle handle)
{
int stride_a = _config.m * _config.k;
int stride_b = _config.n * _config.k;
int stride_c = _config.m * _config.n;
cublas_strided_batched_gemm(handle,
_config.m,
_config.n,
_config.k,
&_config.alpha,
&_config.beta,
_buffer_a,
_buffer_b,
output,
_config.op_A,
_config.op_B,
stride_a,
stride_b,
stride_c,
bsz,
cublasGemmAlgo_t(_config.gemm_algos[0]));
}
void ForwardPlusSave(T* output, const T* _buffer_a, const T* _buffer_b, rocblas_handle handle)
{
int stride_a = _config.m * _config.k;
int stride_b = _config.n * _config.k;
int stride_c = _config.m * _config.n;
cublas_strided_batched_gemm(handle,
_config.m,
_config.n,
_config.k,
&_config.alpha,
&_config.beta,
_buffer_a,
_buffer_b,
output,
_config.op_A,
_config.op_B,
stride_a,
stride_b,
stride_c,
_config.batch_size,
cublasGemmAlgo_t(_config.gemm_algos[0]));
k_buf = _buffer_a;
q_buf = _buffer_b;
}
void Backward(int bsz,
const T* d_output,
const T* _buffer_a,
const T* _buffer_b,
rocblas_handle handle,
T* inpGradA = nullptr,
T* inpGradB = nullptr)
{
int mb = (_config.op_A == rocblas_operation_transpose ? _config.k : _config.m);
int kb = (_config.op_A == rocblas_operation_transpose ? _config.m : _config.k);
int stride_a = mb * _config.n;
int stride_b = _config.n * kb;
int stride_c = _config.m * _config.k;
// B need to transpose.
rocblas_operation op_b = (_config.op_B == rocblas_operation_transpose ? rocblas_operation_none : rocblas_operation_transpose);
// Calculate d_A.
cublas_strided_batched_gemm(handle,
mb,
kb,
_config.n,
&_config.alpha,
&_config.beta,
(_config.op_A == rocblas_operation_transpose ? _buffer_b : d_output),
(_config.op_A == rocblas_operation_transpose ? d_output : _buffer_b),
inpGradA,
rocblas_operation_none,
op_b,
stride_a,
stride_b,
stride_c,
bsz,
cublasGemmAlgo_t(_config.gemm_algos[1]));
// A need to transpose.
rocblas_operation op_a = (_config.op_A == rocblas_operation_transpose ? rocblas_operation_none : rocblas_operation_transpose);
stride_a = _config.m * _config.k;
stride_b = _config.m * _config.n;
stride_c = _config.n * _config.k;
// Calculate d_B.
cublas_strided_batched_gemm(handle,
_config.k,
_config.n,
_config.m,
&_config.alpha,
&_config.beta,
_buffer_a,
d_output,
inpGradB,
op_a,
rocblas_operation_none,
stride_a,
stride_b,
stride_c,
bsz,
cublasGemmAlgo_t(_config.gemm_algos[2]));
}
inline int GetN() const { return _config.k; }
inline const T* GetBufferA() const { return k_buf; }
inline const T* GetBufferB() const { return q_buf; }
inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }
private:
Config _config;
const T* q_buf;
const T* k_buf;
};
csrc/includes/hip/type_shim.h
0 → 100644
View file @
eadbbe09
#include "hip/hip_runtime.h"
/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */
#include <ATen/ATen.h>
// Forward/backward compatiblity hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template
<
typename
T
>
__device__
__forceinline__
T
reduce_block_into_lanes
(
T
*
x
,
T
val
,
int
lanes
=
1
,
bool
share_result
=
false
)
// lanes is intended to be <= 32.
{
int
tid
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
int
blockSize
=
blockDim
.
x
*
blockDim
.
y
;
// blockSize is intended to be a multiple of 32.
if
(
blockSize
>=
64
)
{
x
[
tid
]
=
val
;
__syncthreads
();
}
#pragma unroll
for
(
int
i
=
(
blockSize
>>
1
);
i
>=
64
;
i
>>=
1
)
{
if
(
tid
<
i
)
x
[
tid
]
=
x
[
tid
]
+
x
[
tid
+
i
];
__syncthreads
();
}
T
final
;
if
(
tid
<
32
)
{
if
(
blockSize
>=
64
)
final
=
x
[
tid
]
+
x
[
tid
+
32
];
else
final
=
val
;
// __SYNCWARP();
#pragma unroll
for
(
int
i
=
16
;
i
>=
lanes
;
i
>>=
1
)
final
=
final
+
__shfl_down_sync
(
0xffffffff
,
final
,
i
);
}
if
(
share_result
)
{
if
(
tid
<
lanes
)
x
[
tid
]
=
final
;
// EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads
();
}
return
final
;
}
csrc/lamb/hip/fused_lamb_hip.cpp
0 → 100644
View file @
eadbbe09
/* Copyright 2019 The Microsoft DeepSpeed Team */
#include <torch/extension.h>
// CUDA forward declaration
void
fused_lamb_cuda
(
at
::
Tensor
&
p
,
at
::
Tensor
&
p_copy
,
at
::
Tensor
&
m
,
at
::
Tensor
&
v
,
at
::
Tensor
&
g
,
float
lr
,
float
beta1
,
float
beta2
,
float
max_coeff
,
float
min_coeff
,
float
eps
,
float
grad_scale
,
int
step
,
int
mode
,
int
bias_correction
,
float
decay
,
at
::
Tensor
&
w_l2_i
,
at
::
Tensor
&
u_l2_i
,
at
::
Tensor
&
lamb_coeff_val
);
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
// C++ interface
at
::
Tensor
lamb
(
at
::
Tensor
&
p
,
at
::
Tensor
&
p_copy
,
at
::
Tensor
&
m
,
at
::
Tensor
&
v
,
at
::
Tensor
&
g
,
float
lr
,
float
beta1
,
float
beta2
,
float
max_coeff
,
float
min_coeff
,
float
eps
,
float
grad_scale
,
int
step
,
int
mode
,
int
bias_correction
,
float
decay
)
{
CHECK_INPUT
(
p
);
if
(
p_copy
.
numel
()
>
0
)
CHECK_INPUT
(
p_copy
);
CHECK_INPUT
(
m
);
CHECK_INPUT
(
v
);
CHECK_INPUT
(
g
);
int64_t
num_elem
=
p
.
numel
();
AT_ASSERTM
(
m
.
numel
()
==
num_elem
,
"number of elements in m and p tensors should be equal"
);
AT_ASSERTM
(
v
.
numel
()
==
num_elem
,
"number of elements in v and p tensors should be equal"
);
AT_ASSERTM
(
g
.
numel
()
==
num_elem
,
"number of elements in g and p tensors should be equal"
);
AT_ASSERTM
(
p_copy
.
numel
()
==
num_elem
||
p_copy
.
numel
()
==
0
,
"number of elements in p_copy and p tensors should be equal, or p_copy should be empty"
);
// intermediate for weight L2 reduction
// make sure that the threads per block is at least 512 during the kernel launch otherwise the
// behavious is unexpected
at
::
Tensor
w_l2_i
=
at
::
empty
(
{
512
},
p
.
options
().
dtype
(
p
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
?
at
::
ScalarType
::
Float
:
p
.
type
().
scalarType
()));
// intermediate for update L2 reduction
// make sure that the threads per block is at least 512 during the kernel launch otherwise the
// behavious is unexpected
at
::
Tensor
u_l2_i
=
at
::
empty
(
{
512
},
p
.
options
().
dtype
(
p
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
?
at
::
ScalarType
::
Float
:
p
.
type
().
scalarType
()));
at
::
Tensor
lamb_coeff_val
=
at
::
empty
(
{
1
},
p
.
options
().
dtype
(
p
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
?
at
::
ScalarType
::
Float
:
p
.
type
().
scalarType
()));
fused_lamb_cuda
(
p
,
p_copy
,
m
,
v
,
g
,
lr
,
beta1
,
beta2
,
max_coeff
,
min_coeff
,
eps
,
grad_scale
,
step
,
mode
,
bias_correction
,
decay
,
w_l2_i
,
u_l2_i
,
lamb_coeff_val
);
return
lamb_coeff_val
;
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"lamb"
,
&
lamb
,
"Adam optimized CUDA implementation with LAMB."
);
}
csrc/lamb/hip/fused_lamb_hip_kernel.hip
0 → 100644
View file @
eadbbe09
/* Copyright 2019 The Microsoft DeepSpeed Team */
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/ATen.h"
#include "ATen/TensorUtils.h"
#include "ATen/hip/HIPContext.h"
#include "ATen/hip/detail/IndexUtils.cuh"
//#include "ATen/Type.h"
#include <THH/THHGeneral.h>
#include "ATen/AccumulateType.h"
#include <iostream>
//#include <helper_functions.h>
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
#include <hip/hip_cooperative_groups.h>
#else
#include <cooperative_groups.h>
#endif
#include <hip/hip_runtime_api.h>
#include <stdio.h>
namespace cg = cooperative_groups;
// Utility class used to avoid linker errors with extern
// unsized shared memory arrays with templated type
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
template <typename T>
struct SharedMemory {
// Ensure that we won't compile any un-specialized types
__device__ inline operator T*()
{
extern __device__ void error(void);
error();
return NULL;
}
};
template <>
struct SharedMemory<float> {
__device__ inline operator float*()
{
HIP_DYNAMIC_SHARED( float, s_float)
return s_float;
}
};
template <>
struct SharedMemory<double> {
__device__ inline operator double*()
{
HIP_DYNAMIC_SHARED( double, s_double)
return s_double;
}
};
} // namespace
#include "hip/type_shim.h"
//#include "type_shim.h"
typedef enum {
ADAM_MODE_0 = 0, // eps under square root
ADAM_MODE_1 = 1 // eps outside square root
} adamMode_t;
// s_a and s_b are in shared memory
// g_a and g_b are in shared memory
template <typename T, int blockSize>
__device__ void reduce_block_in_shared_memory(T* s_a, T* s_b, T* g_a, T* g_b)
{
// Handle to thread block group
cg::thread_block cta = cg::this_thread_block();
// perform block reduction in shared memory,
unsigned int tid = cta.thread_rank();
T a_sum = s_a[tid];
T b_sum = s_b[tid];
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
cta.sync();
#else
cg::sync(cta);
#endif
// do reduction in shared mem
if ((blockSize >= 512) && (tid < 256)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 256];
s_b[tid] = b_sum = b_sum + s_b[tid + 256];
}
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
cta.sync();
#else
cg::sync(cta);
#endif
if ((blockSize >= 256) && (tid < 128)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 128];
s_b[tid] = b_sum = b_sum + s_b[tid + 128];
}
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
cta.sync();
#else
cg::sync(cta);
#endif
if ((blockSize >= 128) && (tid < 64)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 64];
s_b[tid] = b_sum = b_sum + s_b[tid + 64];
}
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
cta.sync();
#else
cg::sync(cta);
#endif
#if (__CUDA_ARCH__ >= 300)
if (tid < 32) {
cg::coalesced_group active = cg::coalesced_threads();
// Fetch final intermediate sum from 2nd warp
if (blockSize >= 64) {
a_sum = a_sum + s_a[tid + 32];
b_sum = b_sum + s_b[tid + 32];
}
// Reduce final warp using shuffle
for (int offset = warpSize / 2; offset > 0; offset /= 2) {
a_sum += active.shfl_down(a_sum, offset);
b_sum += active.shfl_down(b_sum, offset);
}
}
#else
if ((blockSize >= 64) && (tid < 32)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 32];
s_b[tid] = b_sum = b_sum + s_b[tid + 32];
}
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
cta.sync();
#else
cg::sync(cta);
#endif
if ((blockSize >= 32) && (tid < 16)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 16];
s_b[tid] = b_sum = b_sum + s_b[tid + 16];
}
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
cta.sync();
#else
cg::sync(cta);
#endif
if ((blockSize >= 16) && (tid < 8)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 8];
s_b[tid] = b_sum = b_sum + s_b[tid + 8];
}
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
cta.sync();
#else
cg::sync(cta);
#endif
if ((blockSize >= 8) && (tid < 4)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 4];
s_b[tid] = b_sum = b_sum + s_b[tid + 4];
}
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
cta.sync();
#else
cg::sync(cta);
#endif
if ((blockSize >= 4) && (tid < 2)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 2];
s_b[tid] = b_sum = b_sum + s_b[tid + 2];
}
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
cta.sync();
#else
cg::sync(cta);
#endif
if ((blockSize >= 2) && (tid < 1)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 1];
s_b[tid] = b_sum = b_sum + s_b[tid + 1];
}
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
cta.sync();
#else
cg::sync(cta);
#endif
#endif
// write result for this block to global mem
if (tid == 0) {
g_a[blockIdx.x] = (T)a_sum;
g_b[blockIdx.x] = (T)b_sum;
}
}
template <typename T, int blockSize>
__device__ void reduce_two_vectors_in_register(T a, T b, T* g_a, T* g_b)
{
const int threadIdInBlock = cg::this_thread_block().thread_rank();
T* s_a = SharedMemory<T>();
T* s_b = SharedMemory<T>() + cg::this_thread_block().size();
s_a[threadIdInBlock] = a;
s_b[threadIdInBlock] = b;
reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part1(
T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T* __restrict__ g,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay,
T* __restrict__ w_l2_i,
T* __restrict__ u_l2_i)
{
// Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = cg::this_thread_block().thread_rank();
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;
T reg_w = 0;
T reg_u = 0;
for (int j = i; j < tsize; j += totThreads) {
T scaled_grad = g[j] / grad_scale;
T pj = p[j];
m[j] = b1 * m[j] + (1 - b1) * scaled_grad;
v[j] = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(v[j] + eps);
else // Mode 1
denom = sqrtf(v[j]) + eps;
T update = (m[j] / denom) + (decay * p[j]);
reg_u += update * update;
reg_w += pj * pj;
}
reduce_two_vectors_in_register<T, blockSize>(reg_w, reg_u, w_l2_i, u_l2_i);
}
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part2(const size_t tsize, T* __restrict__ g_a, T* __restrict__ g_b)
{
T* s_a = SharedMemory<T>();
T* s_b = SharedMemory<T>() + cg::this_thread_block().size();
const int threadIdInBlock = cg::this_thread_block().thread_rank();
s_a[threadIdInBlock] = g_a[threadIdInBlock];
s_b[threadIdInBlock] = g_b[threadIdInBlock];
if (threadIdInBlock >= tsize) {
s_a[threadIdInBlock] = 0.0;
s_b[threadIdInBlock] = 0.0;
}
reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}
template <typename T, typename GRAD_T>
__global__ void lamb_cuda_kernel_part3(
T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T* __restrict__ g,
const float b1,
const float b2,
const float max_coeff,
const float min_coeff,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay,
T* __restrict__ w_l2_i,
T* __restrict__ u_l2_i,
T* __restrict__ lamb_coeff_val)
{
// Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = cg::this_thread_block().thread_rank();
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;
T reg_w = sqrtf(w_l2_i[0]);
T reg_u = sqrtf(u_l2_i[0]);
float lamb_coeff = 1.0;
if (reg_w != 0 and reg_u != 0) {
lamb_coeff = reg_w / reg_u;
if (lamb_coeff > max_coeff) { lamb_coeff = max_coeff; }
if (lamb_coeff < min_coeff) { lamb_coeff = min_coeff; }
}
if (blockId == 0 and threadIdInBlock == 0) {
lamb_coeff_val[0] = lamb_coeff;
// printf("Cuda Lamb Coeff is %.6f \n",lamb_coeff);
}
for (int j = i; j < tsize; j += totThreads) {
T pj = (float)p[j];
T mj = m[j];
T vj = v[j];
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(vj + eps);
else // Mode 1
denom = sqrtf(vj) + eps;
T update = (mj / denom) + (decay * pj);
pj = pj - (step_size * lamb_coeff * update);
p[j] = pj;
if (p_copy != NULL) p_copy[j] = (GRAD_T)pj;
}
}
void fused_lamb_cuda(at::Tensor& p,
at::Tensor& p_copy,
at::Tensor& m,
at::Tensor& v,
at::Tensor& g,
float lr,
float beta1,
float beta2,
float max_coeff,
float min_coeff,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay,
at::Tensor& w_l2_i,
at::Tensor& u_l2_i,
at::Tensor& lamb_coeff)
{
// using namespace at;
// Get tensor size
int tsize = p.numel();
// Determine #threads and #blocks
const int threadsPerBlock = 512;
int num_blocks = (tsize + threadsPerBlock - 1) / threadsPerBlock;
if (num_blocks > 512) num_blocks = 512;
int smemsize = 0;
if (p.type().scalarType() == at::ScalarType::Double)
smemsize = 2 * threadsPerBlock * sizeof(double);
else
smemsize = 2 * threadsPerBlock * sizeof(float);
const dim3 blocks(num_blocks);
const dim3 threads(threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p),
"parameter tensor is too large to be indexed with int32");
// Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - ::pow(beta1, step);
const float bias_correction2 = 1 - ::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
} else {
step_size = lr;
}
hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
if (g.type().scalarType() == at::ScalarType::Half) {
// all other values should be fp32 for half gradients
AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float,
"expected parameter to be of float type");
// dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
g.scalar_type(), "lamb_cuda_kernel", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
hipLaunchKernelGGL(( lamb_cuda_kernel_part1<accscalar_t, scalar_t, threadsPerBlock>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<accscalar_t>(),
p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
m.data<accscalar_t>(),
v.data<accscalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<accscalar_t>(),
u_l2_i.data<accscalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part2<accscalar_t, scalar_t, threadsPerBlock>)
, dim3(1), dim3(threadsPerBlock), smemsize, stream,
num_blocks, w_l2_i.data<accscalar_t>(), u_l2_i.data<accscalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part3<accscalar_t, scalar_t>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<accscalar_t>(),
p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
m.data<accscalar_t>(),
v.data<accscalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
max_coeff,
min_coeff,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<accscalar_t>(),
u_l2_i.data<accscalar_t>(),
lamb_coeff.data<accscalar_t>());
}));
} else {
using namespace at;
AT_DISPATCH_FLOATING_TYPES(
g.scalar_type(), "lamb_cuda_kernel", ([&] {
hipLaunchKernelGGL(( lamb_cuda_kernel_part1<scalar_t, scalar_t, threadsPerBlock>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<scalar_t>(),
NULL, // don't output p_copy for fp32, it's wasted write
m.data<scalar_t>(),
v.data<scalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<scalar_t>(),
u_l2_i.data<scalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part2<scalar_t, scalar_t, threadsPerBlock>)
, dim3(1), dim3(threadsPerBlock), smemsize, stream,
num_blocks, w_l2_i.data<scalar_t>(), u_l2_i.data<scalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part3<scalar_t, scalar_t>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<scalar_t>(),
NULL, // don't output p_copy for fp32, it's wasted write
m.data<scalar_t>(),
v.data<scalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
max_coeff,
min_coeff,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<scalar_t>(),
u_l2_i.data<scalar_t>(),
lamb_coeff.data<scalar_t>());
}));
}
THCudaCheck(hipGetLastError());
}
// template __device__ void reduce_two_vectors_in_register<float,512>(float a, float b, float* g_a,
// float* g_b, cg::grid_group &cgg);
csrc/sparse_attention/hip/utils.cpp
0 → 100644
View file @
eadbbe09
// DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
// https://github.com/ptillet/torch-blocksparse/blob/master/csrc/utils.cpp
#include <torch/extension.h>
#include <string>
#include <tuple>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif
typedef
std
::
vector
<
std
::
tuple
<
int
,
torch
::
Tensor
>>
ret_t
;
void
segment_blocks
(
torch
::
Tensor
layout
,
torch
::
Tensor
idx
,
torch
::
Tensor
scratch
,
int
max_width
,
ret_t
&
ret
)
{
size_t
H
=
layout
.
size
(
0
);
size_t
M
=
layout
.
size
(
1
);
size_t
N
=
layout
.
size
(
2
);
torch
::
Tensor
tmp
=
torch
::
zeros_like
(
layout
);
auto
_tmp
=
tmp
.
accessor
<
int
,
3
>
();
auto
_layout
=
layout
.
accessor
<
int
,
3
>
();
auto
_idx
=
idx
.
accessor
<
int
,
3
>
();
auto
_scratch
=
scratch
.
accessor
<
int
,
3
>
();
std
::
vector
<
int
>
current
(
H
,
0
);
#ifdef _OPENMP
#pragma omp parallel for
#endif
for
(
size_t
h
=
0
;
h
<
H
;
h
++
)
{
// surrounding indices
std
::
vector
<
int
>
ii_left
(
max_width
,
-
1
);
std
::
vector
<
std
::
vector
<
int
>>
ii_top
(
max_width
,
std
::
vector
<
int
>
(
N
,
-
1
));
for
(
size_t
m
=
0
;
m
<
M
;
m
++
)
{
for
(
size_t
n
=
0
;
n
<
N
;
n
++
)
{
int
v
=
_layout
[
h
][
m
][
n
];
if
(
v
==
0
)
continue
;
int
n_left
=
ii_left
[
max_width
-
1
];
int
m_top
=
ii_top
[
max_width
-
1
][
n
];
int
top
=
(
m_top
>=
0
)
?
_tmp
[
h
][
m_top
][
n
]
:
0
;
int
left
=
(
n_left
>=
0
)
?
_tmp
[
h
][
m
][
n_left
]
:
0
;
int
topleft
=
(
m_top
>=
0
&&
n_left
>=
0
)
?
_tmp
[
h
][
m_top
][
n_left
]
:
0
;
int
width
=
std
::
min
(
left
,
std
::
min
(
top
,
topleft
))
+
1
;
// reset width if blocks cannot be
// packed together (i.e., there's a 1 "in the middle")
for
(
int
nn
=
n_left
+
1
;
nn
<
n
;
nn
++
)
if
(
ii_top
[
max_width
-
1
][
nn
]
>
ii_top
[
max_width
-
1
][
n
])
width
=
1
;
_tmp
[
h
][
m
][
n
]
=
width
;
// update n_left ring buffer
for
(
int
k
=
0
;
k
<
max_width
-
1
;
k
++
)
ii_left
[
k
]
=
ii_left
[
k
+
1
];
ii_left
[
max_width
-
1
]
=
n
;
// update ii_top ring buffer
for
(
int
k
=
0
;
k
<
max_width
-
1
;
k
++
)
ii_top
[
k
][
n
]
=
ii_top
[
k
+
1
][
n
];
ii_top
[
max_width
-
1
][
n
]
=
m
;
// block is too small -- skip
if
(
width
!=
max_width
)
continue
;
// retained blocks are set to zeros
for
(
size_t
km
=
0
;
km
<
max_width
;
km
++
)
for
(
size_t
kn
=
0
;
kn
<
max_width
;
kn
++
)
{
int
mm
=
ii_top
[
km
][
n
];
int
nn
=
ii_left
[
kn
];
if
(
mm
<
0
||
nn
<
0
)
continue
;
_layout
[
h
][
mm
][
nn
]
=
0
;
_tmp
[
h
][
mm
][
nn
]
=
0
;
_scratch
[
h
][
current
[
h
]][
0
]
=
(
int
)
h
;
_scratch
[
h
][
current
[
h
]][
1
]
=
(
int
)
mm
;
_scratch
[
h
][
current
[
h
]][
2
]
=
(
int
)
nn
;
_scratch
[
h
][
current
[
h
]][
3
]
=
_idx
[
h
][
mm
][
nn
];
current
[
h
]
++
;
}
}
}
}
std
::
vector
<
torch
::
Tensor
>
to_cat
;
for
(
size_t
h
=
0
;
h
<
H
;
h
++
)
if
(
current
[
h
]
>
0
)
to_cat
.
push_back
(
scratch
[
h
].
slice
(
0
,
0
,
current
[
h
]));
if
(
!
to_cat
.
empty
())
ret
.
push_back
({
max_width
,
torch
::
cat
(
to_cat
)});
}
ret_t
sdd_segment
(
torch
::
Tensor
layout
,
int
start_width
)
{
ret_t
ret
;
// block index
torch
::
Tensor
idx
=
torch
::
zeros_like
(
layout
);
int
current
=
0
;
size_t
H
=
layout
.
size
(
0
);
size_t
M
=
layout
.
size
(
1
);
size_t
N
=
layout
.
size
(
2
);
auto
_layout
=
layout
.
accessor
<
int
,
3
>
();
auto
_idx
=
idx
.
accessor
<
int
,
3
>
();
for
(
size_t
h
=
0
;
h
<
H
;
h
++
)
for
(
size_t
m
=
0
;
m
<
M
;
m
++
)
for
(
size_t
n
=
0
;
n
<
N
;
n
++
)
{
if
(
_layout
[
h
][
m
][
n
]
==
0
)
continue
;
_idx
[
h
][
m
][
n
]
=
current
++
;
}
// scratch memory
//torch::Tensor scratch = torch::empty({H, layout.sum().item<int>(), 4}, layout.dtype());
//aiss debug
torch
::
Tensor
scratch
=
torch
::
empty
({(
long
)
H
,
layout
.
sum
().
item
<
int
>
(),
4
},
layout
.
dtype
());
for
(
int
max_width
=
start_width
;
max_width
>
0
;
max_width
/=
2
)
segment_blocks
(
layout
,
idx
,
scratch
,
max_width
,
ret
);
return
ret
;
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"sdd_segment"
,
&
sdd_segment
,
"SDD segmentation handler"
);
}
csrc/transformer/hip/cublas_wrappers.hip
0 → 100644
View file @
eadbbe09
#include "hip/cublas_wrappers.h"
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
//cublasGemmAlgo_t algo)
rocblas_gemm_algo algo)
{
rocblas_status status = rocblas_gemm_ex(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
rocblas_datatype_f32_r,
(transa == rocblas_operation_none) ? m : k,
(const void*)B,
rocblas_datatype_f32_r,
(transb == rocblas_operation_none) ? k : n,
(const void*)beta,
C,
rocblas_datatype_f32_r,
m,
C,
rocblas_datatype_f32_r,
m,
rocblas_datatype_f32_r,
algo,
0,
0);
if (status != rocblas_status_success) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
//cublasGemmAlgo_t algo)
rocblas_gemm_algo algo)
{
rocblas_status status = rocblas_gemm_ex(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
rocblas_datatype_f16_r,
(transa == rocblas_operation_none) ? m : k,
(const void*)B,
rocblas_datatype_f16_r,
(transb == rocblas_operation_none) ? k : n,
(const void*)beta,
(void*)C,
rocblas_datatype_f16_r,
m,
(void*)C,
rocblas_datatype_f16_r,
m,
rocblas_datatype_f16_r,
algo,
0,
0);
if (status != rocblas_status_success) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_strided_batched_gemm(rocblas_handle handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
rocblas_operation op_A,
rocblas_operation op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
rocblas_gemm_algo algo)
{
rocblas_status status = rocblas_gemm_strided_batched_ex(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
rocblas_datatype_f32_r,
(op_A == rocblas_operation_none) ? m : k,
stride_A,
B,
rocblas_datatype_f32_r,
(op_B == rocblas_operation_none) ? k : n,
stride_B,
beta,
C,
rocblas_datatype_f32_r,
m,
stride_C,
C,
rocblas_datatype_f32_r,
m,
stride_C,
batch,
rocblas_datatype_f32_r,
algo,
0,
0);
if (status != rocblas_status_success) {
fprintf(stderr,
"!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
batch,
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_strided_batched_gemm(rocblas_handle handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
rocblas_operation op_A,
rocblas_operation op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
rocblas_gemm_algo algo)
{
rocblas_status status = rocblas_gemm_strided_batched_ex(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
rocblas_datatype_f16_r,
(op_A == rocblas_operation_none) ? m : k,
stride_A,
B,
rocblas_datatype_f16_r,
(op_B == rocblas_operation_none) ? k : n,
stride_B,
beta,
C,
rocblas_datatype_f16_r,
m,
stride_C,
C,
rocblas_datatype_f16_r,
m,
stride_C,
batch,
rocblas_datatype_f16_r,
algo,
0,
0);
if (status != rocblas_status_success) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
csrc/transformer/hip/cublas_wrappers.hip.bak
0 → 100644
View file @
eadbbe09
#include "cublas_wrappers.h"
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
cublasGemmAlgo_t algo)
{
rocblas_status status = rocblas_gemmex(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
hipR32F,
(transa == rocblas_operation_none) ? m : k,
(const void*)B,
hipR32F,
(transb == rocblas_operation_none) ? k : n,
(const void*)beta,
C,
hipR32F,
m,
hipR32F,
algo);
if (status != rocblas_status_success) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
cublasGemmAlgo_t algo)
{
rocblas_status status = rocblas_gemmex(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
hipR16F,
(transa == rocblas_operation_none) ? m : k,
(const void*)B,
hipR16F,
(transb == rocblas_operation_none) ? k : n,
(const void*)beta,
(void*)C,
hipR16F,
m,
hipR32F,
algo);
if (status != rocblas_status_success) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_strided_batched_gemm(rocblas_handle handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
rocblas_operation op_A,
rocblas_operation op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
cublasGemmAlgo_t algo)
{
rocblas_status status = cublasGemmStridedBatchedEx(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
hipR32F,
(op_A == rocblas_operation_none) ? m : k,
stride_A,
B,
hipR32F,
(op_B == rocblas_operation_none) ? k : n,
stride_B,
beta,
C,
hipR32F,
m,
stride_C,
batch,
hipR32F,
algo);
if (status != rocblas_status_success) {
fprintf(stderr,
"!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
batch,
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_strided_batched_gemm(rocblas_handle handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
rocblas_operation op_A,
rocblas_operation op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
cublasGemmAlgo_t algo)
{
rocblas_status status = cublasGemmStridedBatchedEx(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
hipR16F,
(op_A == rocblas_operation_none) ? m : k,
stride_A,
B,
hipR16F,
(op_B == rocblas_operation_none) ? k : n,
stride_B,
beta,
C,
hipR16F,
m,
stride_C,
batch,
hipR32F,
algo);
if (status != rocblas_status_success) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
csrc/transformer/hip/dropout_kernels.hip
0 → 100644
View file @
eadbbe09
#include "hip/hip_runtime.h"
#include "hip/custom_hip_layers.h"
const int unroll_factor = 4;
__global__ void dropout_kernel(const int N,
const float ratio,
float* out,
const float* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float4 rand = hiprand_uniform4(&state);
uint8_t m[unroll_factor];
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
int i = j * unroll_factor;
mask[i] = (uint8_t)m[0];
mask[i + 1] = (uint8_t)m[1];
mask[i + 2] = (uint8_t)m[2];
mask[i + 3] = (uint8_t)m[3];
out[i] = Xdata[i] * scale * m[0];
out[i + 1] = Xdata[i + 1] * scale * m[1];
out[i + 2] = Xdata[i + 2] * scale * m[2];
out[i + 3] = Xdata[i + 3] * scale * m[3];
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
out[i] = Xdata[i] * scale * m;
mask[i] = m;
}
}
}
__global__ void dropout_kernel(const int N,
const float ratio,
__half* out,
const __half* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
#ifdef __STOCHASTIC_MODE__
const __half2 h_scale = __float2half2_rn(scale);
const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
float2* out_cast = reinterpret_cast<float2*>(out);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
uint32_t m_32;
uint8_t* m = reinterpret_cast<uint8_t*>(&m_32);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
__half2 mask_h[2];
float2 mask_f[2];
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_f = x_cast[j];
__half2* x_h = reinterpret_cast<__half2*>(&x_f);
float4 rand = hiprand_uniform4(&state);
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
float* mask_f_data = &mask_f[0].x;
#pragma unroll
for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);
mask_h[0] = __float22half2_rn(mask_f[0]);
mask_h[1] = __float22half2_rn(mask_f[1]);
result_h[0] = x_h[0] * h_scale * mask_h[0];
result_h[1] = x_h[1] * h_scale * mask_h[1];
out_cast[j] = result_f;
mask_cast[j] = m_32;
}
#else
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
int i = j * unroll_factor;
const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
float2 vals_half_f[2];
vals_half_f[0] = __half22float2(vals_half[0]);
vals_half_f[1] = __half22float2(vals_half[1]);
uint8_t m[unroll_factor];
float4 rand = hiprand_uniform4(&state);
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);
mask[i] = m[0];
mask[i + 1] = m[1];
mask[i + 2] = m[2];
mask[i + 3] = m[3];
}
#endif
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
out[i] = __float2half((float)Xdata[i] * scale * m);
mask[i] = m;
}
}
}
__global__ void dropout_kernel_bwd(const int N,
const float ratio,
const float* Xdata,
float* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
int i = j * unroll_factor;
out[i] = mask[i] ? Xdata[i] * scale : 0.0;
out[i + 1] = mask[i + 1] ? Xdata[i + 1] * scale : 0.0;
out[i + 2] = mask[i + 2] ? Xdata[i + 2] * scale : 0.0;
out[i + 3] = mask[i + 3] ? Xdata[i + 3] * scale : 0.0;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) { out[i] = mask[i] ? Xdata[i] * scale : 0.0; }
}
}
__global__ void dropout_kernel_bwd(const int N,
const float ratio,
const __half* Xdata,
__half* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
#ifdef __STOCHASTIC_MODE__
const __half2 h_scale = __float2half2_rn(scale);
const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
float2* out_cast = reinterpret_cast<float2*>(out);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_f = x_cast[j];
__half2* x_h = reinterpret_cast<__half2*>(&x_f);
uint32_t m_32 = mask_cast[j];
uint8_t* m = (uint8_t*)&m_32;
__half2 mask_h[2];
float2 mask_f[2];
float* mask_f_data = &mask_f[0].x;
#pragma unroll
for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);
#pragma unroll
for (int i = 0; i < 2; i++) mask_h[i] = __float22half2_rn(mask_f[i]);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
result_h[0] = x_h[0] * h_scale * mask_h[0];
result_h[1] = x_h[1] * h_scale * mask_h[1];
out_cast[j] = result_f;
}
#else
const __half h_scale = __float2half(scale);
const __half h_zero = __float2half(0.0);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
int i = j * unroll_factor;
const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
uint8_t* m = mask + i;
float2 vals_half_f[2];
vals_half_f[0] = __half22float2(vals_half[0]);
vals_half_f[1] = __half22float2(vals_half[1]);
out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);
}
#endif
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) {
out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
}
}
}
template <typename T>
void launch_dropout(T* out,
const T* vals,
uint8_t* mask,
int total_count,
int dim,
float ratio,
hipStream_t stream,
bool bwd)
{
assert(unroll_factor == 4);
dim3 grid_dim = DS_GET_BLOCKS(total_count / unroll_factor);
dim3 block_dim = DS_CUDA_NUM_THREADS;
if (dim > 512) {
block_dim.x >>= 1;
grid_dim.x <<= 1;
}
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
if (bwd)
hipLaunchKernelGGL(( dropout_kernel_bwd), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, ratio, vals, out, mask, seed);
else
hipLaunchKernelGGL(( dropout_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, ratio, out, vals, mask, seed);
}
template void launch_dropout(float* out,
const float* vals,
uint8_t* mask,
int total_count,
int dim,
float ratio,
hipStream_t stream,
bool);
template void launch_dropout(__half* out,
const __half* vals,
uint8_t* mask,
int total_count,
int dim,
float ratio,
hipStream_t stream,
bool);
__global__ void dropout_grad_kernel(const int N, const float scale, float* Xdata, uint8_t* mask)
{
CUDA_1D_KERNEL_LOOP(i, N) { Xdata[i] *= scale * mask[i]; }
}
__global__ void dropout_grad_kernel(const int N, const float scale, __half* Xdata, uint8_t* mask)
{
const __half2 h_scale = __float2half2_rn(scale);
float2* x_cast = reinterpret_cast<float2*>(Xdata);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_data = x_cast[j];
uint32_t m_32 = mask_cast[j];
uint8_t* m = (uint8_t*)&m_32;
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
#ifdef __STOCHASTIC_MODE__
__half2* x_data_h = reinterpret_cast<__half2*>(&x_data);
__half2 mask_h[2];
float2 mask_f[2];
float* mask_f_data = &mask_f[0].x;
#pragma unroll
for (int i = 0; i < unroll_factor; i++) *(mask_f_data++) = (float)(m[i]);
mask_h[0] = __float22half2_rn(mask_f[0]);
mask_h[1] = __float22half2_rn(mask_f[1]);
result_h[0] = x_data_h[0] * h_scale * mask_h[0];
result_h[1] = x_data_h[1] * h_scale * mask_h[1];
#else
__half* x_data_h = reinterpret_cast<__half*>(&x_data);
float2 result[2];
result[0].x = (float)x_data_h[0] * scale * m[0];
result[0].y = (float)x_data_h[1] * scale * m[1];
result[1].x = (float)x_data_h[2] * scale * m[2];
result[1].y = (float)x_data_h[3] * scale * m[3];
result_h[0] = __float22half2_rn(result[0]);
result_h[1] = __float22half2_rn(result[1]);
#endif
x_cast[j] = result_f;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) {
Xdata[i] = __float2half((float)Xdata[i] * scale * mask[i]);
}
}
}
template <typename T>
void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, hipStream_t stream)
{
assert(unroll_factor == 4);
const float scale = 1. / (1. - ratio);
hipLaunchKernelGGL(( dropout_grad_kernel), dim3(DS_GET_BLOCKS(total_count / unroll_factor)),
dim3(DS_CUDA_NUM_THREADS),
0,
stream, total_count, scale, vals, mask);
}
template void launch_dropout_grad(float* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
template void launch_dropout_grad(__half* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
__global__ void dropout_grad_kernel(const int N,
const float scale,
const float* Xdata,
float* out,
uint8_t* mask)
{
CUDA_1D_KERNEL_LOOP(i, N) { out[i] = Xdata[i] * scale * mask[i]; }
}
__global__ void dropout_grad_kernel(const int N,
const float scale,
const __half* Xdata,
__half* out,
uint8_t* mask)
{
const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
float2* out_cast = reinterpret_cast<float2*>(out);
const uint32_t* mask_cast = reinterpret_cast<const uint32_t*>(mask);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_data = x_cast[j];
uint32_t m_32 = mask_cast[j];
uint8_t* m = (uint8_t*)&m_32;
__half* x_data_h = reinterpret_cast<__half*>(&x_data);
float2 result[2];
result[0].x = (float)x_data_h[0] * scale * m[0];
result[0].y = (float)x_data_h[1] * scale * m[1];
result[1].x = (float)x_data_h[2] * scale * m[2];
result[1].y = (float)x_data_h[3] * scale * m[3];
result_h[0] = __float22half2_rn(result[0]);
result_h[1] = __float22half2_rn(result[1]);
out_cast[j] = result_f;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) {
out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
}
}
}
template <typename T>
void launch_dropout_grad(T* vals_out,
const T* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream)
{
assert(unroll_factor == 4);
const float scale = 1. / (1. - ratio);
hipLaunchKernelGGL(( dropout_grad_kernel), dim3(DS_GET_BLOCKS(total_count / unroll_factor)),
dim3(DS_CUDA_NUM_THREADS),
0,
stream, total_count, scale, vals, vals_out, mask);
}
template void launch_dropout_grad(float*,
const float* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
template void launch_dropout_grad(__half*,
const __half* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const float* bias,
float* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
float4* Xdata_cast = reinterpret_cast<float4*>(Xdata);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = hiprand_uniform4(&state);
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
float4 x_data = Xdata_cast[j];
float4 b_data = bias_cast[j % (dim / unroll_factor)];
x_data.x += b_data.x;
x_data.y += b_data.y;
x_data.z += b_data.z;
x_data.w += b_data.w;
x_data.x = x_data.x * scale * m[0];
x_data.y = x_data.y * scale * m[1];
x_data.z = x_data.z * scale * m[2];
x_data.w = x_data.w * scale * m[3];
mask_32[j] = m_32;
Xdata_cast[j] = x_data;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = Xdata[i] + bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
Xdata[i] = x_data * scale * m;
mask[i] = m;
}
}
}
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const __half* bias,
__half* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
float2* Xdata_cast = reinterpret_cast<float2*>(Xdata);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = hiprand_uniform4(&state);
float2 data_f;
__half2* data_h = reinterpret_cast<__half2*>(&data_f);
float2 bias_f;
__half2* bias_h = reinterpret_cast<__half2*>(&bias_f);
data_f = Xdata_cast[j];
bias_f = bias_cast[j % (dim / unroll_factor)];
float2 data_h_0 = __half22float2(data_h[0]);
float2 data_h_1 = __half22float2(data_h[1]);
float2 bias_h_0 = __half22float2(bias_h[0]);
float2 bias_h_1 = __half22float2(bias_h[1]);
data_h_0.x += bias_h_0.x;
data_h_0.y += bias_h_0.y;
data_h_1.x += bias_h_1.x;
data_h_1.y += bias_h_1.y;
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
data_h_1.y = __float2half(data_h_1.y * scale * m[3]);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
result_h[0] = __float22half2_rn(data_h_0);
result_h[1] = __float22half2_rn(data_h_1);
Xdata_cast[j] = result_f;
mask_32[j] = m_32;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = (float)Xdata[i] + (float)bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
Xdata[i] = __float2half(x_data * scale * m);
mask[i] = m;
}
}
}
template <typename T>
void launch_dropout(T* out,
const T* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream)
{
assert(unroll_factor == 4);
int total_count = batch * dim / unroll_factor;
dim3 grid_dim = DS_GET_BLOCKS(total_count);
dim3 block_dim = DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
hipLaunchKernelGGL(( dropout_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, dim, ratio, bias, out, mask, seed);
}
template void launch_dropout(float*,
const float* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
template void launch_dropout(__half*,
const __half* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const float* input,
const float* residual,
const float* bias,
float* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
float4* out_cast = reinterpret_cast<float4*>(out);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
const float4* residual_cast = reinterpret_cast<const float4*>(residual);
const float4* input_cast = reinterpret_cast<const float4*>(input);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = hiprand_uniform4(&state);
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
float4 out_data;
float4 b_data = bias_cast[j % (dim / unroll_factor)];
float4 res_data = residual_cast[j];
float4 inp_data = input_cast[j];
out_data.x = (b_data.x + inp_data.x);
out_data.y = (b_data.y + inp_data.y);
out_data.z = (b_data.z + inp_data.z);
out_data.w = (b_data.w + inp_data.w);
out_data.x = out_data.x * scale * m[0];
out_data.y = out_data.y * scale * m[1];
out_data.z = out_data.z * scale * m[2];
out_data.w = out_data.w * scale * m[3];
out_data.x += res_data.x;
out_data.y += res_data.y;
out_data.z += res_data.z;
out_data.w += res_data.w;
mask_32[j] = m_32;
out_cast[j] = out_data;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = input[i] + bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
x_data = x_data * scale * m;
x_data += residual[i];
out[i] = x_data;
mask[i] = m;
}
}
}
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const __half* input,
const __half* residual,
const __half* bias,
__half* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
float2* out_cast = reinterpret_cast<float2*>(out);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
const float2* residual_cast = reinterpret_cast<const float2*>(residual);
const float2* input_cast = reinterpret_cast<const float2*>(input);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = hiprand_uniform4(&state);
float2 data_f;
__half2* data_h = reinterpret_cast<__half2*>(&data_f);
float2 bias_f;
__half2* bias_h = reinterpret_cast<__half2*>(&bias_f);
float2 residual_f;
__half2* residual_h = reinterpret_cast<__half2*>(&residual_f);
float2 input_f;
__half2* input_h = reinterpret_cast<__half2*>(&input_f);
bias_f = bias_cast[j % (dim / unroll_factor)];
residual_f = residual_cast[j];
input_f = input_cast[j];
float2 data_h_0 = __half22float2(data_h[0]);
float2 data_h_1 = __half22float2(data_h[1]);
float2 bias_h_0 = __half22float2(bias_h[0]);
float2 bias_h_1 = __half22float2(bias_h[1]);
float2 residual_h_0 = __half22float2(residual_h[0]);
float2 residual_h_1 = __half22float2(residual_h[1]);
float2 input_h_0 = __half22float2(input_h[0]);
float2 input_h_1 = __half22float2(input_h[1]);
data_h_0.x = (bias_h_0.x + input_h_0.x);
data_h_0.y = (bias_h_0.y + input_h_0.y);
data_h_1.x = (bias_h_1.x + input_h_1.x);
data_h_1.y = (bias_h_1.y + input_h_1.y);
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
data_h_1.y = __float2half(data_h_1.y * scale * m[3]);
data_h_0.x += residual_h_0.x;
data_h_0.y += residual_h_0.y;
data_h_1.x += residual_h_1.x;
data_h_1.y += residual_h_1.y;
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
result_h[0] = __float22half2_rn(data_h_0);
result_h[1] = __float22half2_rn(data_h_1);
out_cast[j] = result_f;
mask_32[j] = m_32;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = (float)input[i] + (float)bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
x_data = x_data * scale * m;
x_data += (float)residual[i];
out[i] = __float2half(x_data);
mask[i] = m;
}
}
}
template <typename T>
void launch_dropout(T* out,
const T* input,
const T* residual,
const T* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream)
{
assert(unroll_factor == 4);
int total_count = batch * dim / unroll_factor;
dim3 grid_dim = DS_GET_BLOCKS(total_count);
dim3 block_dim = DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
hipLaunchKernelGGL(( dropout_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, dim, ratio, input, residual, bias, out, mask, seed);
}
template void launch_dropout(float*,
const float*,
const float* residual,
const float* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
template void launch_dropout(__half*,
const __half*,
const __half* residual,
const __half* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
csrc/transformer/hip/ds_transformer_hip.cpp
0 → 100644
View file @
eadbbe09
#include <torch/extension.h>
#include <rocblas.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "hip/Timer.h"
#include "hip/context.h"
#include "hip/cublas_wrappers.h"
#include "hip/custom_hip_layers.h"
#include "hip/ds_transformer_hip.h"
static
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
void
>>
s_transformer_layers
;
const
int
init_seq_length
=
128
;
// C++ interface
template
<
typename
T
>
size_t
get_workspace_size
(
int
maxBatchSize
,
int
seq_len
,
int
hidden_size
,
int
intermediate_size
,
int
heads
,
bool
training
,
bool
gelu_checkpoint
)
{
size_t
workSpacesize
=
4
*
(
size_t
(
maxBatchSize
)
*
seq_len
*
hidden_size
);
if
(
training
)
{
workSpacesize
+=
((
std
::
max
)((
size_t
(
maxBatchSize
)
*
seq_len
*
intermediate_size
),
2
*
(
size_t
(
maxBatchSize
)
*
heads
*
seq_len
*
seq_len
)));
if
(
gelu_checkpoint
)
workSpacesize
+=
2
*
(
size_t
(
maxBatchSize
)
*
seq_len
*
intermediate_size
);
}
return
workSpacesize
;
// * sizeof(T);
}
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
template
<
typename
T
>
BertTransformerLayer
<
T
>::
BertTransformerLayer
(
int
layer_id
,
int
batch_size
,
int
hidden_size
,
int
num_heads
,
int
intermediate_size
,
int
seq_length
,
float
attn_prob_dropout_ratio
,
float
hidden_output_dropout_ratio
,
float
layer_norm_eps
,
bool
pre_or_postLayerNorm
,
const
std
::
vector
<
std
::
array
<
int
,
3
>>&
gemm_algos
,
bool
attn_dropout_checkpoint
,
bool
normalize_invertible
,
bool
gelu_checkpoint
,
bool
stochastic_mode
)
:
_layer_id
(
layer_id
),
_batch_size
(
batch_size
),
_hidden_size
(
hidden_size
),
_heads
(
num_heads
),
_intermediate_size
(
intermediate_size
),
_seq_length
(
seq_length
),
_training
(
true
),
_pre_or_postLayerNorm
(
pre_or_postLayerNorm
),
_attn_dropout_checkpoint
(
attn_dropout_checkpoint
),
_normalize_invertible
(
normalize_invertible
),
_gelu_checkpoint
(
gelu_checkpoint
),
_stochastic_mode
(
stochastic_mode
),
_stream
(
Context
::
Instance
().
GetCurrentStream
()),
_cublasHandle
(
Context
::
Instance
().
GetCublasHandle
()),
_qkv_linear
(
typename
FeedForward
<
T
>::
Config
(
batch_size
*
seq_length
,
3
*
hidden_size
,
hidden_size
,
gemm_algos
[
0
])),
_attn_out_linear
(
typename
FeedForward
<
T
>::
Config
(
batch_size
*
seq_length
,
hidden_size
,
hidden_size
,
gemm_algos
[
0
])),
_attn_layer_norm
(
typename
Normalize_Layer
<
T
>::
Config
(
batch_size
,
seq_length
,
hidden_size
,
layer_norm_eps
,
true
,
!
normalize_invertible
)),
_layer_norm
(
typename
Normalize_Layer
<
T
>::
Config
(
batch_size
,
seq_length
,
hidden_size
,
layer_norm_eps
,
true
,
!
normalize_invertible
)),
_ff1
(
typename
FeedForward
<
T
>::
Config
(
batch_size
*
seq_length
,
_intermediate_size
,
hidden_size
,
gemm_algos
[
1
])),
_ff2
(
typename
FeedForward
<
T
>::
Config
(
batch_size
*
seq_length
,
hidden_size
,
_intermediate_size
,
gemm_algos
[
2
])),
_softmax
(
typename
Softmax
<
T
>::
Config
(
batch_size
,
num_heads
,
seq_length
)),
_gelu
(
typename
Gelu
<
T
>::
Config
(
_intermediate_size
)),
_attn_prob_dropout
(
typename
Dropout
<
T
>::
Config
(
attn_prob_dropout_ratio
,
_seq_length
)),
_attn_output_dropout
(
typename
Dropout
<
T
>::
Config
(
hidden_output_dropout_ratio
,
_hidden_size
)),
_layer_output_dropout
(
typename
Dropout
<
T
>::
Config
(
hidden_output_dropout_ratio
,
_hidden_size
)),
_attn_scores
(
typename
StridedBatchGemm
<
T
>::
Config
(
_batch_size
*
_heads
,
_seq_length
,
_seq_length
,
_hidden_size
/
_heads
,
//(T(1.0) / T(sqrt(_hidden_size / _heads))),
(
T
(
1.0
/
(
sqrt
(
_hidden_size
/
_heads
)))),
T
(
0.0
),
rocblas_operation_transpose
,
rocblas_operation_none
,
gemm_algos
[
3
])),
_attn_context
(
typename
StridedBatchGemm
<
T
>::
Config
(
_batch_size
*
_heads
,
_hidden_size
/
_heads
,
_seq_length
,
_seq_length
,
T
(
1.0
),
T
(
0.0
),
rocblas_operation_none
,
rocblas_operation_none
,
gemm_algos
[
4
]))
{
assert
(
_hidden_size
%
_heads
==
0
);
Initialize
();
}
template
<
typename
T
>
BertTransformerLayer
<
T
>::~
BertTransformerLayer
()
{
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
Initialize
()
{
//aiss debug:rocm has no CUBLAS_TENSOR_OP_MATH
//if (std::is_same<T, __half>::value) rocblas_set_math_mode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
Forward
(
int
bsz
,
const
T
*
input_ptr
,
const
T
*
input_mask_ptr
,
const
T
*
attn_qkvw_ptr
,
const
T
*
attn_qkvb_ptr
,
const
T
*
attn_ow_ptr
,
const
T
*
attn_ob_ptr
,
const
T
*
attn_nw_ptr
,
const
T
*
attn_nb_ptr
,
const
T
*
inter_w_ptr
,
const
T
*
inter_b_ptr
,
const
T
*
output_w_ptr
,
const
T
*
output_b_ptr
,
const
T
*
norm_w_ptr
,
const
T
*
norm_b_ptr
,
T
*
out_ptr
,
T
*
inp_norm_ptr
,
T
*
q_tf_ptr
,
T
*
k_tf_ptr
,
T
*
v_tf_ptr
,
T
*
soft_out_ptr
,
T
*
ctx_bufB_ptr
,
T
*
attn_o_inp_ptr
,
T
*
add_res_ptr
,
T
*
ff1_inp_ptr
,
T
*
gelu_inp_ptr
,
T
*
ff2_inp_ptr
)
{
rocblas_set_stream
(
_cublasHandle
,
_stream
);
if
(
!
_stochastic_mode
)
hipStreamSynchronize
(
_stream
);
T
*
workspace
=
static_cast
<
T
*>
(
Context
::
Instance
().
GetWorkSpace
());
size_t
small_buf_size
=
bsz
*
_seq_length
*
_hidden_size
;
T
*
buf_0
=
workspace
;
T
*
buf_1
=
buf_0
+
small_buf_size
;
T
*
buf_2
=
buf_1
;
if
(
_normalize_invertible
)
{
add_res_ptr
=
buf_1
+
3
*
small_buf_size
;
buf_2
=
add_res_ptr
;
}
if
(
_gelu_checkpoint
)
buf_2
+=
small_buf_size
;
if
(
_attn_dropout_checkpoint
)
ctx_bufB_ptr
=
(
_gelu_checkpoint
?
(
buf_2
+
(
_intermediate_size
/
_hidden_size
)
*
small_buf_size
)
:
(
buf_1
+
4
*
small_buf_size
));
int
bsz_seq
=
bsz
*
_seq_length
;
if
(
_pre_or_postLayerNorm
)
{
if
(
_layer_norm
.
UseMean
())
_layer_norm
.
ForwardCheckpoint
(
bsz_seq
,
inp_norm_ptr
,
input_ptr
,
norm_w_ptr
,
norm_b_ptr
,
_stream
,
true
);
else
_layer_norm
.
Forward
(
bsz_seq
,
inp_norm_ptr
,
input_ptr
,
norm_w_ptr
,
norm_b_ptr
,
_stream
,
true
);
}
if
(
_pre_or_postLayerNorm
)
_qkv_linear
.
Forward
(
bsz_seq
,
inp_norm_ptr
,
attn_qkvw_ptr
,
buf_0
,
_cublasHandle
);
else
_qkv_linear
.
Forward
(
bsz_seq
,
input_ptr
,
attn_qkvw_ptr
,
buf_0
,
_cublasHandle
);
launch_bias_add_transform_0213
<
T
>
(
q_tf_ptr
,
buf_0
,
attn_qkvb_ptr
,
bsz
,
_seq_length
,
_hidden_size
,
_heads
,
_stream
,
3
);
int
bsz_heads
=
bsz
*
_heads
;
// attention scores
_attn_scores
.
Forward
(
bsz_heads
,
soft_out_ptr
,
k_tf_ptr
,
q_tf_ptr
,
_cublasHandle
);
// Softmax + Mask
_softmax
.
Forward
(
bsz
,
soft_out_ptr
,
input_mask_ptr
,
_stream
);
// attn prob dropout.
_attn_prob_dropout
.
Forward
(
bsz_heads
*
_seq_length
,
ctx_bufB_ptr
,
soft_out_ptr
,
_stream
);
// attention context
_attn_context
.
Forward
(
bsz_heads
,
buf_1
,
v_tf_ptr
,
ctx_bufB_ptr
,
_cublasHandle
);
launch_transform4d_0213
<
T
>
(
attn_o_inp_ptr
,
buf_1
,
bsz
,
_heads
,
_seq_length
,
_hidden_size
,
_stream
,
1
);
if
(
_pre_or_postLayerNorm
)
_attn_out_linear
.
Forward
(
bsz_seq
,
attn_o_inp_ptr
,
attn_ow_ptr
,
buf_1
,
_cublasHandle
);
else
_attn_out_linear
.
Forward
(
bsz_seq
,
attn_o_inp_ptr
,
attn_ow_ptr
,
ff1_inp_ptr
,
_cublasHandle
);
// attn output dropout.
if
(
_pre_or_postLayerNorm
)
_attn_output_dropout
.
ForwardWithBias
(
bsz_seq
,
add_res_ptr
,
buf_1
,
input_ptr
,
attn_ob_ptr
,
_stream
);
else
_attn_output_dropout
.
ForwardWithBias
(
bsz_seq
,
add_res_ptr
,
ff1_inp_ptr
,
input_ptr
,
attn_ob_ptr
,
_stream
);
if
(
_pre_or_postLayerNorm
)
{
if
(
_attn_layer_norm
.
UseMean
())
_attn_layer_norm
.
ForwardCheckpoint
(
bsz_seq
,
ff1_inp_ptr
,
add_res_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
_stream
,
true
);
else
_attn_layer_norm
.
Forward
(
bsz_seq
,
ff1_inp_ptr
,
add_res_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
_stream
,
true
);
}
else
{
if
(
_attn_layer_norm
.
UseMean
())
_attn_layer_norm
.
ForwardCheckpoint
(
bsz_seq
,
ff1_inp_ptr
,
add_res_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
_stream
,
true
);
else
_attn_layer_norm
.
Forward
(
bsz_seq
,
ff1_inp_ptr
,
add_res_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
_stream
,
true
);
}
_ff1
.
Forward
(
bsz_seq
,
ff1_inp_ptr
,
inter_w_ptr
,
(
_gelu_checkpoint
?
ff2_inp_ptr
:
gelu_inp_ptr
),
_cublasHandle
);
_gelu
.
ForwardWithBiasAdd
(
bsz_seq
,
(
_gelu_checkpoint
?
ff2_inp_ptr
:
gelu_inp_ptr
),
inter_b_ptr
,
(
_gelu_checkpoint
?
buf_2
:
ff2_inp_ptr
),
_stream
);
_ff2
.
Forward
(
bsz_seq
,
(
_gelu_checkpoint
?
buf_2
:
ff2_inp_ptr
),
output_w_ptr
,
out_ptr
,
_cublasHandle
);
// layer output dropout.
if
(
_pre_or_postLayerNorm
)
_layer_output_dropout
.
ForwardWithBias
(
bsz_seq
,
out_ptr
,
out_ptr
,
add_res_ptr
,
output_b_ptr
,
_stream
);
else
_layer_output_dropout
.
ForwardWithBias
(
bsz_seq
,
inp_norm_ptr
,
out_ptr
,
ff1_inp_ptr
,
output_b_ptr
,
_stream
);
if
(
!
_pre_or_postLayerNorm
)
{
if
(
_layer_norm
.
UseMean
())
_layer_norm
.
ForwardCheckpoint
(
bsz_seq
,
out_ptr
,
inp_norm_ptr
,
norm_w_ptr
,
norm_b_ptr
,
_stream
,
true
);
else
_layer_norm
.
Forward
(
bsz_seq
,
out_ptr
,
inp_norm_ptr
,
norm_w_ptr
,
norm_b_ptr
,
_stream
,
true
);
}
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
Backward
(
int
bsz
,
const
T
*
grad_output_ptr
,
const
T
*
input_ptr
,
const
T
*
output_ptr
,
const
T
*
inp_norm_ptr
,
const
T
*
q_tf_ptr
,
const
T
*
k_tf_ptr
,
const
T
*
v_tf_ptr
,
const
T
*
soft_out_ptr
,
const
T
*
ctx_bufB_ptr
,
const
T
*
attn_o_inp_ptr
,
const
T
*
add_res_ptr
,
const
T
*
ff1_inp_ptr
,
const
T
*
gelu_inp_ptr
,
const
T
*
ff2_inp_ptr
,
const
T
*
input_mask_ptr
,
const
T
*
attn_qkvw_ptr
,
const
T
*
attn_ow_ptr
,
const
T
*
attn_nw_ptr
,
const
T
*
attn_nb_ptr
,
const
T
*
inter_w_ptr
,
const
T
*
inter_b_ptr
,
const
T
*
output_w_ptr
,
const
T
*
norm_w_ptr
,
const
T
*
norm_b_ptr
,
T
*
grad_input_ptr
,
T
*
grad_attn_qkvw_ptr
,
T
*
grad_attn_qkvb_ptr
,
T
*
grad_attn_ow_ptr
,
T
*
grad_attn_ob_ptr
,
T
*
grad_attn_nw_ptr
,
T
*
grad_attn_nb_ptr
,
T
*
grad_inter_w_ptr
,
T
*
grad_inter_b_ptr
,
T
*
grad_output_w_ptr
,
T
*
grad_output_b_ptr
,
T
*
grad_norm_w_ptr
,
T
*
grad_norm_b_ptr
)
{
rocblas_set_stream
(
_cublasHandle
,
_stream
);
if
(
!
_stochastic_mode
)
hipStreamSynchronize
(
_stream
);
T
*
workspace
=
static_cast
<
T
*>
(
Context
::
Instance
().
GetWorkSpace
());
size_t
small_buf_size
=
bsz
*
_seq_length
*
_hidden_size
;
T
*
buf_0
=
workspace
;
T
*
buf_1
=
buf_0
+
small_buf_size
;
T
*
buf_2
=
buf_1
+
small_buf_size
;
T
*
buf_3
=
buf_2
+
small_buf_size
;
T
*
ff2_buf
=
(
_gelu_checkpoint
?
buf_3
+
(
bsz
*
_seq_length
*
_intermediate_size
)
:
buf_3
+
small_buf_size
);
T
*
ctx_bufB_ptr_recomp
=
ff2_buf
+
(
_seq_length
*
_seq_length
*
bsz
*
_heads
);
hipStream_t
streams
[
2
]
=
{
_stream
,
_stream
};
int
bsz_seq
=
bsz
*
_seq_length
;
int
bsz_heads
=
bsz
*
_heads
;
if
(
!
_pre_or_postLayerNorm
)
{
if
(
_layer_norm
.
UseMean
())
_layer_norm
.
Backward
(
bsz_seq
,
grad_output_ptr
,
norm_w_ptr
,
grad_norm_w_ptr
,
grad_norm_b_ptr
,
streams
,
buf_1
,
inp_norm_ptr
);
else
_layer_norm
.
Backward
(
bsz_seq
,
grad_output_ptr
,
norm_w_ptr
,
norm_b_ptr
,
grad_norm_w_ptr
,
grad_norm_b_ptr
,
streams
,
buf_1
,
output_ptr
);
}
if
(
_pre_or_postLayerNorm
)
_layer_output_dropout
.
Backward
(
bsz_seq
,
buf_0
,
grad_output_ptr
,
_stream
);
else
_layer_output_dropout
.
Backward
(
bsz_seq
,
buf_0
,
buf_1
,
_stream
);
const
T
*
layer_dropout_buf
=
_layer_output_dropout
.
HasDropout
()
?
buf_0
:
(
_pre_or_postLayerNorm
?
grad_output_ptr
:
buf_1
);
if
(
_gelu_checkpoint
)
_gelu
.
ForwardWithBiasAdd
(
bsz_seq
,
ff2_inp_ptr
,
inter_b_ptr
,
buf_2
,
_stream
);
_ff2
.
Backward
(
bsz_seq
,
layer_dropout_buf
,
(
_gelu_checkpoint
?
buf_2
:
ff2_inp_ptr
),
output_w_ptr
,
grad_output_w_ptr
,
grad_output_b_ptr
,
_cublasHandle
,
_stream
,
ff2_buf
);
_gelu
.
Backward
(
bsz_seq
,
ff2_buf
,
(
_gelu_checkpoint
?
ff2_inp_ptr
:
gelu_inp_ptr
),
inter_b_ptr
,
_stream
);
_ff1
.
Backward
(
bsz_seq
,
ff2_buf
,
ff1_inp_ptr
,
inter_w_ptr
,
grad_inter_w_ptr
,
grad_inter_b_ptr
,
_cublasHandle
,
_stream
,
buf_3
);
if
(
!
_pre_or_postLayerNorm
)
launch_fused_add2
<
T
>
(
buf_2
,
buf_3
,
buf_1
,
bsz
,
_seq_length
,
_hidden_size
,
_stream
);
if
(
_pre_or_postLayerNorm
)
{
if
(
_attn_layer_norm
.
UseMean
())
_attn_layer_norm
.
BackwardFusedAdd
(
bsz_seq
,
buf_3
,
grad_output_ptr
,
attn_nw_ptr
,
grad_attn_nw_ptr
,
grad_attn_nb_ptr
,
streams
,
buf_0
,
add_res_ptr
);
else
_attn_layer_norm
.
BackwardFusedAdd
(
bsz_seq
,
buf_3
,
grad_output_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
grad_attn_nw_ptr
,
grad_attn_nb_ptr
,
streams
,
buf_0
,
ff1_inp_ptr
);
}
else
{
if
(
_attn_layer_norm
.
UseMean
())
_attn_layer_norm
.
Backward
(
bsz_seq
,
buf_2
,
attn_nw_ptr
,
grad_attn_nw_ptr
,
grad_attn_nb_ptr
,
streams
,
buf_0
,
add_res_ptr
);
else
_attn_layer_norm
.
Backward
(
bsz_seq
,
buf_2
,
attn_nw_ptr
,
attn_nb_ptr
,
grad_attn_nw_ptr
,
grad_attn_nb_ptr
,
streams
,
buf_0
,
ff1_inp_ptr
);
}
_attn_output_dropout
.
Backward
(
bsz_seq
,
buf_2
,
buf_0
,
_stream
);
T
*
attn_output_dropout_buf
=
_attn_output_dropout
.
HasDropout
()
?
buf_2
:
buf_0
;
_attn_out_linear
.
Backward
(
bsz_seq
,
attn_output_dropout_buf
,
attn_o_inp_ptr
,
attn_ow_ptr
,
grad_attn_ow_ptr
,
grad_attn_ob_ptr
,
_cublasHandle
,
_stream
,
buf_1
);
launch_transform_0213
<
T
>
(
buf_2
,
buf_1
,
bsz
,
_seq_length
,
_hidden_size
,
_heads
,
_stream
);
if
(
_attn_prob_dropout
.
HasDropout
())
{
if
(
_attn_dropout_checkpoint
)
_attn_prob_dropout
.
Forward
(
bsz_heads
*
_seq_length
,
ctx_bufB_ptr_recomp
,
soft_out_ptr
,
_stream
,
true
);
_attn_context
.
Backward
(
bsz_heads
,
buf_2
,
v_tf_ptr
,
(
_attn_dropout_checkpoint
?
ctx_bufB_ptr_recomp
:
ctx_bufB_ptr
),
_cublasHandle
,
buf_3
,
ff2_buf
);
}
else
_attn_context
.
Backward
(
bsz_heads
,
buf_2
,
v_tf_ptr
,
soft_out_ptr
,
_cublasHandle
,
buf_3
,
ff2_buf
);
_attn_prob_dropout
.
Backward
(
bsz_heads
*
_seq_length
,
ff2_buf
,
_stream
);
_softmax
.
Backward
(
bsz
,
ff2_buf
,
soft_out_ptr
,
_stream
);
_attn_scores
.
Backward
(
bsz_heads
,
ff2_buf
,
k_tf_ptr
,
q_tf_ptr
,
_cublasHandle
,
buf_2
,
buf_1
);
launch_transform4d_0213
(
ff2_buf
,
buf_1
,
bsz
,
_heads
,
_seq_length
,
_hidden_size
,
_stream
,
3
);
if
(
_pre_or_postLayerNorm
)
_qkv_linear
.
Backward
(
bsz_seq
,
ff2_buf
,
inp_norm_ptr
,
attn_qkvw_ptr
,
grad_attn_qkvw_ptr
,
grad_attn_qkvb_ptr
,
_cublasHandle
,
_stream
,
buf_2
);
else
_qkv_linear
.
Backward
(
bsz_seq
,
ff2_buf
,
input_ptr
,
attn_qkvw_ptr
,
grad_attn_qkvw_ptr
,
grad_attn_qkvb_ptr
,
_cublasHandle
,
_stream
,
buf_2
);
if
(
_pre_or_postLayerNorm
)
{
if
(
_layer_norm
.
UseMean
())
_layer_norm
.
BackwardFusedAdd
(
bsz_seq
,
buf_2
,
buf_0
,
norm_w_ptr
,
grad_norm_w_ptr
,
grad_norm_b_ptr
,
streams
,
grad_input_ptr
,
input_ptr
);
else
_layer_norm
.
BackwardFusedAdd
(
bsz_seq
,
buf_2
,
buf_0
,
norm_w_ptr
,
norm_b_ptr
,
grad_norm_w_ptr
,
grad_norm_b_ptr
,
streams
,
grad_input_ptr
,
inp_norm_ptr
);
}
else
launch_fused_add2
<
T
>
(
grad_input_ptr
,
buf_2
,
buf_0
,
bsz
,
_seq_length
,
_hidden_size
,
_stream
);
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
SetTrainingMode
(
bool
training
)
{
// Dropout will be skipped when not in training model.
_attn_prob_dropout
.
SetTrainingMode
(
training
);
_attn_output_dropout
.
SetTrainingMode
(
training
);
_layer_output_dropout
.
SetTrainingMode
(
training
);
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
SetIntermediateBuffers
(
uint8_t
*
attn_prob_dropout_mask_ptr
,
uint8_t
*
attn_output_dropout_mask_ptr
,
uint8_t
*
layer_output_dropout_mask_ptr
,
T
*
attn_layer_norm_var
,
T
*
attn_layer_norm_mean
,
T
*
layer_norm_var
,
T
*
layer_norm_mean
)
{
_attn_prob_dropout
.
SetMask
(
attn_prob_dropout_mask_ptr
);
_attn_output_dropout
.
SetMask
(
attn_output_dropout_mask_ptr
);
_layer_output_dropout
.
SetMask
(
layer_output_dropout_mask_ptr
);
_attn_layer_norm
.
SetVar
(
attn_layer_norm_var
);
_attn_layer_norm
.
SetMean
(
attn_layer_norm_mean
);
_layer_norm
.
SetVar
(
layer_norm_var
);
_layer_norm
.
SetMean
(
layer_norm_mean
);
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
SetSeqLength
(
int
seq_len
)
{
_seq_length
=
seq_len
;
_softmax
.
SetSeqLength
(
_seq_length
);
_attn_prob_dropout
.
SetDimension
(
_seq_length
);
_attn_scores
.
SetConfig
(
_seq_length
,
_seq_length
,
_hidden_size
/
_heads
);
_attn_context
.
SetConfig
(
_hidden_size
/
_heads
,
_seq_length
,
_seq_length
);
}
template
<
typename
T
>
int
create_transformer_layer
(
int
layer_id
,
int
batch_size
,
int
hidden_dim
,
int
num_heads
,
int
intermediate_size
,
float
attn_dropout_ratio
,
float
hidden_dropout_ratio
,
float
layer_norm_eps
,
int
seed
,
bool
pre_or_postLayerNorm
,
bool
test_gemm
,
bool
attn_dropout_checkpoint
,
bool
normalize_invertible
,
bool
gelu_checkpoint
,
bool
stochastic_mode
)
{
Context
::
Instance
().
SetSeed
(
seed
);
Context
::
Instance
().
TestGemmFP16
(
test_gemm
,
batch_size
,
init_seq_length
,
num_heads
,
hidden_dim
/
num_heads
);
auto
layer
=
std
::
make_shared
<
BertTransformerLayer
<
T
>>
(
layer_id
,
batch_size
,
hidden_dim
,
num_heads
,
intermediate_size
,
init_seq_length
,
attn_dropout_ratio
,
hidden_dropout_ratio
,
layer_norm_eps
,
pre_or_postLayerNorm
,
Context
::
Instance
().
GetGemmAlgos
(),
attn_dropout_checkpoint
,
normalize_invertible
,
gelu_checkpoint
,
stochastic_mode
);
s_transformer_layers
[
layer_id
]
=
layer
;
std
::
string
dtype
=
(
std
::
is_same
<
T
,
__half
>::
value
)
?
"half"
:
"float"
;
std
::
cout
<<
"layer #"
<<
layer_id
<<
" is created with date type ["
<<
dtype
<<
"]."
<<
std
::
endl
;
return
0
;
}
template
<
typename
T
>
std
::
vector
<
torch
::
Tensor
>
ds_transformer_forward
(
int
layer_id
,
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
input_mask
,
const
torch
::
Tensor
&
attn_qkvw
,
const
torch
::
Tensor
&
attn_qkvb
,
const
torch
::
Tensor
&
attn_ow
,
const
torch
::
Tensor
&
attn_ob
,
const
torch
::
Tensor
&
attn_nw
,
const
torch
::
Tensor
&
attn_nb
,
const
torch
::
Tensor
&
inter_w
,
const
torch
::
Tensor
&
inter_b
,
const
torch
::
Tensor
&
output_w
,
const
torch
::
Tensor
&
output_b
,
const
torch
::
Tensor
&
norm_w
,
const
torch
::
Tensor
&
norm_b
,
bool
training_mode
,
bool
prelayernorm
,
bool
attn_dropout_checkpoint
,
bool
normalize_invertible
,
bool
gelu_checkpoint
)
{
CHECK_INPUT
(
input
);
CHECK_INPUT
(
input_mask
);
CHECK_INPUT
(
attn_qkvw
);
CHECK_INPUT
(
attn_qkvb
);
CHECK_INPUT
(
attn_ow
);
CHECK_INPUT
(
attn_ob
);
CHECK_INPUT
(
attn_nw
);
CHECK_INPUT
(
attn_nb
);
CHECK_INPUT
(
inter_w
);
CHECK_INPUT
(
inter_b
);
CHECK_INPUT
(
output_w
);
CHECK_INPUT
(
output_b
);
CHECK_INPUT
(
norm_w
);
CHECK_INPUT
(
norm_b
);
int
bsz
=
input
.
size
(
0
);
const
T
*
input_ptr
=
(
const
T
*
)
input
.
data_ptr
();
const
T
*
input_mask_ptr
=
(
const
T
*
)
input_mask
.
data_ptr
();
const
T
*
attn_qkvw_ptr
=
(
const
T
*
)
attn_qkvw
.
data_ptr
();
const
T
*
attn_qkvb_ptr
=
(
const
T
*
)
attn_qkvb
.
data_ptr
();
const
T
*
attn_ow_ptr
=
(
const
T
*
)
attn_ow
.
data_ptr
();
const
T
*
attn_ob_ptr
=
(
const
T
*
)
attn_ob
.
data_ptr
();
const
T
*
attn_nw_ptr
=
(
const
T
*
)
attn_nw
.
data_ptr
();
const
T
*
attn_nb_ptr
=
(
const
T
*
)
attn_nb
.
data_ptr
();
const
T
*
inter_w_ptr
=
(
const
T
*
)
inter_w
.
data_ptr
();
const
T
*
inter_b_ptr
=
(
const
T
*
)
inter_b
.
data_ptr
();
const
T
*
output_w_ptr
=
(
const
T
*
)
output_w
.
data_ptr
();
const
T
*
output_b_ptr
=
(
const
T
*
)
output_b
.
data_ptr
();
const
T
*
norm_w_ptr
=
(
const
T
*
)
norm_w
.
data_ptr
();
const
T
*
norm_b_ptr
=
(
const
T
*
)
norm_b
.
data_ptr
();
auto
output
=
torch
::
empty_like
(
input
);
T
*
out_ptr
=
(
T
*
)
output
.
data_ptr
();
auto
options
=
torch
::
TensorOptions
()
.
dtype
(
input
.
options
().
dtype
())
.
layout
(
torch
::
kStrided
)
.
device
(
torch
::
kCUDA
)
.
requires_grad
(
true
);
auto
uint8_options
=
torch
::
TensorOptions
()
.
dtype
(
torch
::
kInt8
)
.
layout
(
torch
::
kStrided
)
.
device
(
torch
::
kCUDA
)
.
requires_grad
(
false
);
std
::
shared_ptr
<
BertTransformerLayer
<
T
>>
layer
=
std
::
static_pointer_cast
<
BertTransformerLayer
<
T
>>
(
s_transformer_layers
[
layer_id
]);
int
seq_len
=
layer
->
GetSeqLength
();
if
(
input
.
size
(
1
)
!=
seq_len
)
{
seq_len
=
input
.
size
(
1
);
layer
->
SetSeqLength
(
seq_len
);
}
auto
workspace
=
torch
::
empty
({
get_workspace_size
<
T
>
(
bsz
,
seq_len
,
layer
->
GetHiddenSize
(),
layer
->
GetIntermediateSize
(),
layer
->
GetNumHeads
(),
layer
->
IsTrainingMode
(),
layer
->
GeluCheckpoint
())},
options
);
Context
::
Instance
().
SetWorkSpace
((
T
*
)
workspace
.
data_ptr
());
auto
inp_norm
=
((
prelayernorm
||
!
normalize_invertible
)
?
torch
::
empty_like
(
input
)
:
output
);
auto
add_res
=
(
normalize_invertible
?
inp_norm
:
torch
::
empty_like
(
input
));
auto
attn_o_inp
=
torch
::
empty_like
(
input
);
auto
qkv_tf
=
torch
::
empty
({(
bsz
*
seq_len
),
output_w
.
size
(
0
)
*
3
},
options
);
auto
attn_prob_dropout_mask
=
torch
::
empty
({(
bsz
*
layer
->
GetNumHeads
()
*
seq_len
),
seq_len
},
uint8_options
);
auto
attn_output_dropout_mask
=
torch
::
empty
({(
bsz
*
seq_len
),
layer
->
GetHiddenSize
()},
uint8_options
);
auto
layer_output_dropout_mask
=
torch
::
empty
({(
bsz
*
seq_len
),
layer
->
GetHiddenSize
()},
uint8_options
);
auto
attn_layer_norm_var
=
torch
::
empty
({(
bsz
*
seq_len
)},
options
);
auto
attn_layer_norm_mean
=
torch
::
empty
({(
bsz
*
seq_len
)},
options
);
auto
layer_norm_var
=
torch
::
empty
({(
bsz
*
seq_len
)},
options
);
auto
layer_norm_mean
=
torch
::
empty
({(
bsz
*
seq_len
)},
options
);
T
*
inp_norm_ptr
=
(
T
*
)
inp_norm
.
data_ptr
();
T
*
add_res_ptr
=
(
T
*
)
add_res
.
data_ptr
();
T
*
q_tf_ptr
=
(
T
*
)
qkv_tf
.
data_ptr
();
T
*
k_tf_ptr
=
q_tf_ptr
+
(
bsz
*
seq_len
*
output_w
.
size
(
0
));
//(T*)k_tf.data_ptr();
T
*
v_tf_ptr
=
k_tf_ptr
+
(
bsz
*
seq_len
*
output_w
.
size
(
0
));
//(T*)v_tf.data_ptr();
T
*
attn_o_inp_ptr
=
(
T
*
)
attn_o_inp
.
data_ptr
();
torch
::
Tensor
ff2_inp
=
torch
::
empty
({(
bsz
*
seq_len
),
output_w
.
size
(
1
)},
options
);
torch
::
Tensor
gelu_inp
=
(
gelu_checkpoint
?
ff2_inp
:
torch
::
empty
({(
bsz
*
seq_len
),
output_w
.
size
(
1
)},
options
));
auto
ff1_inp
=
torch
::
empty_like
(
input
);
T
*
ff2_inp_ptr
=
(
T
*
)
ff2_inp
.
data_ptr
();
T
*
gelu_inp_ptr
=
(
T
*
)
gelu_inp
.
data_ptr
();
T
*
ff1_inp_ptr
=
(
T
*
)
ff1_inp
.
data_ptr
();
torch
::
Tensor
soft_out
=
torch
::
empty
({(
bsz
*
layer
->
GetNumHeads
()
*
seq_len
),
seq_len
},
options
);
torch
::
Tensor
ctx_bufB
=
(
attn_dropout_checkpoint
?
soft_out
:
torch
::
empty
({(
bsz
*
layer
->
GetNumHeads
()
*
seq_len
),
seq_len
},
options
));
T
*
soft_out_ptr
=
(
T
*
)
soft_out
.
data_ptr
();
T
*
ctx_bufB_ptr
=
(
T
*
)
ctx_bufB
.
data_ptr
();
layer
->
SetTrainingMode
(
training_mode
);
layer
->
SetIntermediateBuffers
((
uint8_t
*
)
attn_prob_dropout_mask
.
data_ptr
(),
(
uint8_t
*
)
attn_output_dropout_mask
.
data_ptr
(),
(
uint8_t
*
)
layer_output_dropout_mask
.
data_ptr
(),
(
T
*
)
attn_layer_norm_var
.
data_ptr
(),
(
T
*
)
attn_layer_norm_mean
.
data_ptr
(),
(
T
*
)
layer_norm_var
.
data_ptr
(),
(
T
*
)
layer_norm_mean
.
data_ptr
());
layer
->
Forward
(
bsz
,
input_ptr
,
input_mask_ptr
,
attn_qkvw_ptr
,
attn_qkvb_ptr
,
attn_ow_ptr
,
attn_ob_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
inter_w_ptr
,
inter_b_ptr
,
output_w_ptr
,
output_b_ptr
,
norm_w_ptr
,
norm_b_ptr
,
out_ptr
,
inp_norm_ptr
,
q_tf_ptr
,
k_tf_ptr
,
v_tf_ptr
,
soft_out_ptr
,
ctx_bufB_ptr
,
attn_o_inp_ptr
,
add_res_ptr
,
ff1_inp_ptr
,
gelu_inp_ptr
,
ff2_inp_ptr
);
return
{
output
,
inp_norm
,
qkv_tf
,
soft_out
,
ctx_bufB
,
attn_o_inp
,
add_res
,
ff1_inp
,
gelu_inp
,
ff2_inp
,
attn_prob_dropout_mask
,
attn_output_dropout_mask
,
layer_output_dropout_mask
,
attn_layer_norm_var
,
attn_layer_norm_mean
,
layer_norm_var
,
layer_norm_mean
};
}
template
<
typename
T
>
std
::
vector
<
torch
::
Tensor
>
ds_transformer_backward
(
int
layer_id
,
const
torch
::
Tensor
&
grad_output
,
const
torch
::
Tensor
&
output
,
const
torch
::
Tensor
&
inp_norm
,
const
torch
::
Tensor
&
qkv_tf
,
const
torch
::
Tensor
&
soft_out
,
const
torch
::
Tensor
&
ctx_bufB
,
const
torch
::
Tensor
&
attn_o_inp
,
const
torch
::
Tensor
&
add_res
,
const
torch
::
Tensor
&
ff1_inp
,
const
torch
::
Tensor
&
gelu_inp
,
const
torch
::
Tensor
&
ff2_inp
,
const
torch
::
Tensor
&
attn_prob_dropout_mask
,
const
torch
::
Tensor
&
attn_output_dropout_mask
,
const
torch
::
Tensor
&
layer_output_dropout_mask
,
const
torch
::
Tensor
&
attn_layer_norm_var
,
const
torch
::
Tensor
&
attn_layer_norm_mean
,
const
torch
::
Tensor
&
layer_norm_var
,
const
torch
::
Tensor
&
layer_norm_mean
,
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
input_mask
,
const
torch
::
Tensor
&
attn_qkvw
,
const
torch
::
Tensor
&
attn_qkvb
,
const
torch
::
Tensor
&
attn_ow
,
const
torch
::
Tensor
&
attn_ob
,
const
torch
::
Tensor
&
attn_nw
,
const
torch
::
Tensor
&
attn_nb
,
const
torch
::
Tensor
&
inter_w
,
const
torch
::
Tensor
&
inter_b
,
const
torch
::
Tensor
&
output_w
,
const
torch
::
Tensor
&
output_b
,
const
torch
::
Tensor
&
norm_w
,
const
torch
::
Tensor
&
norm_b
)
{
auto
g_output
=
grad_output
.
contiguous
();
CHECK_INPUT
(
g_output
);
CHECK_INPUT
(
output
);
CHECK_INPUT
(
inp_norm
);
CHECK_INPUT
(
qkv_tf
);
CHECK_INPUT
(
add_res
);
CHECK_INPUT
(
soft_out
);
CHECK_INPUT
(
ctx_bufB
);
CHECK_INPUT
(
attn_o_inp
);
CHECK_INPUT
(
ff1_inp
);
CHECK_INPUT
(
gelu_inp
);
CHECK_INPUT
(
ff2_inp
);
CHECK_INPUT
(
input
);
CHECK_INPUT
(
input_mask
);
CHECK_INPUT
(
attn_qkvw
);
CHECK_INPUT
(
attn_qkvb
);
CHECK_INPUT
(
attn_ow
);
CHECK_INPUT
(
attn_ob
);
CHECK_INPUT
(
attn_nw
);
CHECK_INPUT
(
attn_nb
);
CHECK_INPUT
(
inter_w
);
CHECK_INPUT
(
inter_b
);
CHECK_INPUT
(
output_w
);
CHECK_INPUT
(
output_b
);
CHECK_INPUT
(
norm_w
);
CHECK_INPUT
(
norm_b
);
int
bsz
=
g_output
.
size
(
0
);
std
::
shared_ptr
<
BertTransformerLayer
<
T
>>
layer
=
std
::
static_pointer_cast
<
BertTransformerLayer
<
T
>>
(
s_transformer_layers
[
layer_id
]);
int
seq_len
=
layer
->
GetSeqLength
();
if
(
g_output
.
size
(
1
)
!=
seq_len
)
{
seq_len
=
g_output
.
size
(
1
);
layer
->
SetSeqLength
(
seq_len
);
}
auto
options
=
torch
::
TensorOptions
()
.
dtype
(
g_output
.
options
().
dtype
())
.
layout
(
torch
::
kStrided
)
.
device
(
torch
::
kCUDA
)
.
requires_grad
(
true
);
auto
workspace
=
torch
::
empty
({
get_workspace_size
<
T
>
(
bsz
,
seq_len
,
layer
->
GetHiddenSize
(),
layer
->
GetIntermediateSize
(),
layer
->
GetNumHeads
(),
layer
->
IsTrainingMode
(),
layer
->
GeluCheckpoint
())},
options
);
Context
::
Instance
().
SetWorkSpace
((
T
*
)
workspace
.
data_ptr
());
auto
grad_input
=
torch
::
empty_like
(
input
);
auto
grad_attn_qkvw
=
torch
::
empty_like
(
attn_qkvw
);
auto
grad_attn_qkvb
=
torch
::
empty_like
(
attn_qkvb
);
auto
grad_attn_ow
=
torch
::
empty_like
(
attn_ow
);
auto
grad_attn_ob
=
torch
::
empty_like
(
attn_ob
);
auto
grad_attn_nw
=
torch
::
empty_like
(
attn_nw
);
auto
grad_attn_nb
=
torch
::
empty_like
(
attn_nb
);
auto
grad_inter_w
=
torch
::
empty_like
(
inter_w
);
auto
grad_inter_b
=
torch
::
empty_like
(
inter_b
);
auto
grad_output_w
=
torch
::
empty_like
(
output_w
);
auto
grad_output_b
=
torch
::
empty_like
(
output_b
);
auto
grad_norm_w
=
torch
::
empty_like
(
norm_w
);
auto
grad_norm_b
=
torch
::
empty_like
(
norm_b
);
// inputs.
const
T
*
grad_output_ptr
=
(
const
T
*
)
g_output
.
data_ptr
();
const
T
*
input_ptr
=
(
const
T
*
)
input
.
data_ptr
();
const
T
*
output_ptr
=
(
const
T
*
)
output
.
data_ptr
();
const
T
*
inp_norm_ptr
=
(
const
T
*
)
inp_norm
.
data_ptr
();
const
T
*
q_tf_ptr
=
(
const
T
*
)
qkv_tf
.
data_ptr
();
const
T
*
add_res_ptr
=
(
const
T
*
)
add_res
.
data_ptr
();
const
T
*
k_tf_ptr
=
q_tf_ptr
+
(
bsz
*
layer
->
GetSeqLength
()
*
output_w
.
size
(
0
));
//(const T*)k_tf.data_ptr();
const
T
*
v_tf_ptr
=
k_tf_ptr
+
(
bsz
*
layer
->
GetSeqLength
()
*
output_w
.
size
(
0
));
//(const T*)v_tf.data_ptr();
const
T
*
ff1_inp_ptr
=
(
const
T
*
)
ff1_inp
.
data_ptr
();
const
T
*
gelu_inp_ptr
=
(
const
T
*
)
gelu_inp
.
data_ptr
();
const
T
*
ff2_inp_ptr
=
(
const
T
*
)
ff2_inp
.
data_ptr
();
const
T
*
ctx_bufB_ptr
=
(
const
T
*
)
ctx_bufB
.
data_ptr
();
const
T
*
soft_out_ptr
=
(
const
T
*
)
soft_out
.
data_ptr
();
const
T
*
attn_o_inp_ptr
=
(
const
T
*
)
attn_o_inp
.
data_ptr
();
const
T
*
input_mask_ptr
=
(
const
T
*
)
input_mask
.
data_ptr
();
const
T
*
attn_qkvw_ptr
=
(
const
T
*
)
attn_qkvw
.
data_ptr
();
const
T
*
attn_ow_ptr
=
(
const
T
*
)
attn_ow
.
data_ptr
();
const
T
*
attn_nw_ptr
=
(
const
T
*
)
attn_nw
.
data_ptr
();
const
T
*
attn_nb_ptr
=
(
const
T
*
)
attn_nb
.
data_ptr
();
const
T
*
inter_w_ptr
=
(
const
T
*
)
inter_w
.
data_ptr
();
const
T
*
inter_b_ptr
=
(
const
T
*
)
inter_b
.
data_ptr
();
const
T
*
output_w_ptr
=
(
const
T
*
)
output_w
.
data_ptr
();
const
T
*
norm_w_ptr
=
(
const
T
*
)
norm_w
.
data_ptr
();
const
T
*
norm_b_ptr
=
(
const
T
*
)
norm_b
.
data_ptr
();
// outputs.
T
*
grad_input_ptr
=
(
T
*
)
grad_input
.
data_ptr
();
T
*
grad_attn_qkvw_ptr
=
(
T
*
)
grad_attn_qkvw
.
data_ptr
();
T
*
grad_attn_qkvb_ptr
=
(
T
*
)
grad_attn_qkvb
.
data_ptr
();
T
*
grad_attn_ow_ptr
=
(
T
*
)
grad_attn_ow
.
data_ptr
();
T
*
grad_attn_ob_ptr
=
(
T
*
)
grad_attn_ob
.
data_ptr
();
T
*
grad_attn_nw_ptr
=
(
T
*
)
grad_attn_nw
.
data_ptr
();
T
*
grad_attn_nb_ptr
=
(
T
*
)
grad_attn_nb
.
data_ptr
();
T
*
grad_inter_w_ptr
=
(
T
*
)
grad_inter_w
.
data_ptr
();
T
*
grad_inter_b_ptr
=
(
T
*
)
grad_inter_b
.
data_ptr
();
T
*
grad_output_w_ptr
=
(
T
*
)
grad_output_w
.
data_ptr
();
T
*
grad_output_b_ptr
=
(
T
*
)
grad_output_b
.
data_ptr
();
T
*
grad_norm_w_ptr
=
(
T
*
)
grad_norm_w
.
data_ptr
();
T
*
grad_norm_b_ptr
=
(
T
*
)
grad_norm_b
.
data_ptr
();
layer
->
SetIntermediateBuffers
((
uint8_t
*
)
attn_prob_dropout_mask
.
data_ptr
(),
(
uint8_t
*
)
attn_output_dropout_mask
.
data_ptr
(),
(
uint8_t
*
)
layer_output_dropout_mask
.
data_ptr
(),
(
T
*
)
attn_layer_norm_var
.
data_ptr
(),
(
T
*
)
attn_layer_norm_mean
.
data_ptr
(),
(
T
*
)
layer_norm_var
.
data_ptr
(),
(
T
*
)
layer_norm_mean
.
data_ptr
());
layer
->
Backward
(
bsz
,
grad_output_ptr
,
input_ptr
,
output_ptr
,
inp_norm_ptr
,
q_tf_ptr
,
k_tf_ptr
,
v_tf_ptr
,
soft_out_ptr
,
ctx_bufB_ptr
,
attn_o_inp_ptr
,
add_res_ptr
,
ff1_inp_ptr
,
gelu_inp_ptr
,
ff2_inp_ptr
,
input_mask_ptr
,
attn_qkvw_ptr
,
attn_ow_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
inter_w_ptr
,
inter_b_ptr
,
output_w_ptr
,
norm_w_ptr
,
norm_b_ptr
,
grad_input_ptr
,
grad_attn_qkvw_ptr
,
grad_attn_qkvb_ptr
,
grad_attn_ow_ptr
,
grad_attn_ob_ptr
,
grad_attn_nw_ptr
,
grad_attn_nb_ptr
,
grad_inter_w_ptr
,
grad_inter_b_ptr
,
grad_output_w_ptr
,
grad_output_b_ptr
,
grad_norm_w_ptr
,
grad_norm_b_ptr
);
return
{
grad_input
,
grad_attn_qkvw
,
grad_attn_qkvb
,
grad_attn_ow
,
grad_attn_ob
,
grad_attn_nw
,
grad_attn_nb
,
grad_inter_w
,
grad_inter_b
,
grad_output_w
,
grad_output_b
,
grad_norm_w
,
grad_norm_b
};
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"forward_fp32"
,
&
ds_transformer_forward
<
float
>
,
"DeepSpeed Transformer forward with fp32 (CUDA)"
);
m
.
def
(
"forward_fp16"
,
&
ds_transformer_forward
<
__half
>
,
"DeepSpeed Transformer forward with fp16 (CUDA)"
);
m
.
def
(
"backward_fp32"
,
&
ds_transformer_backward
<
float
>
,
"DeepSpeed Transformer backward with fp32 (CUDA)"
);
m
.
def
(
"backward_fp16"
,
&
ds_transformer_backward
<
__half
>
,
"DeepSpeed Transformer backward with fp16 (CUDA)"
);
m
.
def
(
"create_transformer_layer_fp32"
,
&
create_transformer_layer
<
float
>
,
"Create DeepSpeed Transformer Transformer Layer with fp32 (CUDA)"
);
m
.
def
(
"create_transformer_layer_fp16"
,
&
create_transformer_layer
<
__half
>
,
"Create DeepSpeed Transformer Transformer Layer with fp16 (CUDA)"
);
}
csrc/transformer/hip/gelu_kernels.hip
0 → 100644
View file @
eadbbe09
#include "hip/hip_runtime.h"
#include "hip/custom_hip_layers.h"
inline __device__ float gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}
inline __device__ float d_gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
float x2mul = x * x * mul_param;
float tan_h = tanhf(sqrt_param * (x + x * x2mul));
float dg1 = 0.5f * (1.0f + tan_h);
float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
float dg3 = dg2 * 3 * x2mul;
return (dg1 + dg2 + dg3);
}
/*
Fused bias add with GELU
Loads a vector of 4 elements each iteration, for stride
iterations. It was written with the intention to launch 256 thread
threadblocks, so to launch for bert-large, we would set ITERATIONS
to 4. This is currently done automatically as a heuristic, setting
the number of iterations as blocks of 1024.
For FP16, the values are loaded from memory as __half, but converted
to FP32 for the arithmetic itself, to prevent numerous overflow on
the intermediate hyperbolic tangent, since there's no intrinsic
that computes it directly.
*/
__global__ void gelu_kernel(const float* input, float* vals, int intermediate_size)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
int iterations = intermediate_size / blockDim.x / 4;
int row_stride = intermediate_size / 4;
const float4* input_cast = reinterpret_cast<const float4*>(input);
float4* vals_cast = reinterpret_cast<float4*>(vals);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 data = input_cast[row * row_stride + i * loop_stride + id];
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
vals_cast[row * row_stride + i * loop_stride + id] = data;
}
}
}
__global__ void gelu_kernel(const __half* input, __half* vals, int intermediate_size)
{
#if __CUDA_ARCH__ >= 700
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
int iterations = intermediate_size / blockDim.x / 4;
int row_stride = intermediate_size / 4;
const float2* input_cast = reinterpret_cast<const float2*>(input);
float2* vals_cast = reinterpret_cast<float2*>(vals);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
}
}
#endif
}
__global__ void fused_bias_gelu(const float* input,
const float* bias,
float* vals,
int intermediate_size)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
int iterations = intermediate_size / blockDim.x / 4;
int row_stride = intermediate_size / 4;
const float4* input_cast = reinterpret_cast<const float4*>(input);
float4* vals_cast = reinterpret_cast<float4*>(vals);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 data = input_cast[row * row_stride + i * loop_stride + id];
float4 bias_data = bias_cast[i * loop_stride + id];
data.x += bias_data.x;
data.y += bias_data.y;
data.z += bias_data.z;
data.w += bias_data.w;
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
vals_cast[row * row_stride + i * loop_stride + id] = data;
}
}
}
__global__ void fused_bias_gelu(const __half* input,
const __half* bias,
__half* vals,
int intermediate_size)
{
#if __CUDA_ARCH__ >= 700
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
int iterations = intermediate_size / blockDim.x / 4;
int row_stride = intermediate_size / 4;
const float2* input_cast = reinterpret_cast<const float2*>(input);
float2* vals_cast = reinterpret_cast<float2*>(vals);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];
float2 bias_vec = bias_cast[i * loop_stride + id];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
low_data.x += low_bias.x;
low_data.y += low_bias.y;
high_data.x += high_bias.x;
high_data.y += high_bias.y;
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
}
}
#endif
}
__global__ void d_gelu_func(float* d_output,
const float* gelu_input,
const float* bias,
int intermediate_size)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
int iterations = intermediate_size / blockDim.x / 4;
int row_stride = intermediate_size / 4;
float4* d_output_cast = reinterpret_cast<float4*>(d_output);
const float4* gelu_input_cast = reinterpret_cast<const float4*>(gelu_input);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
float4 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
float4 bias_data = bias_cast[i * loop_stride + id];
gelu_input_data.x += bias_data.x;
gelu_input_data.y += bias_data.y;
gelu_input_data.z += bias_data.z;
gelu_input_data.w += bias_data.w;
output_data.x *= d_gelu(gelu_input_data.x);
output_data.y *= d_gelu(gelu_input_data.y);
output_data.z *= d_gelu(gelu_input_data.z);
output_data.w *= d_gelu(gelu_input_data.w);
d_output_cast[row * row_stride + i * loop_stride + id] = output_data;
}
}
}
__global__ void d_gelu_func(__half* d_output,
const __half* gelu_input,
const __half* bias,
int intermediate_size)
{
#if __CUDA_ARCH__ >= 700
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
int iterations = intermediate_size / blockDim.x / 4;
int row_stride = intermediate_size / 4;
float2* d_output_cast = reinterpret_cast<float2*>(d_output);
const float2* gelu_input_cast = reinterpret_cast<const float2*>(gelu_input);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
#pragma unroll
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
float2 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
float2 bias_vec = bias_cast[i * loop_stride + id];
__half2* output_data_half = reinterpret_cast<__half2*>(&output_data);
__half2* gelu_input_data_half = reinterpret_cast<__half2*>(&gelu_input_data);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 output_half_0 = __half22float2(output_data_half[0]);
float2 output_half_1 = __half22float2(output_data_half[1]);
float2 gelu_input_half_0 = __half22float2(gelu_input_data_half[0]);
float2 gelu_input_half_1 = __half22float2(gelu_input_data_half[1]);
float2 bias_half_0 = __half22float2(bias_half[0]);
float2 bias_half_1 = __half22float2(bias_half[1]);
gelu_input_half_0.x += bias_half_0.x;
gelu_input_half_0.y += bias_half_0.y;
gelu_input_half_1.x += bias_half_1.x;
gelu_input_half_1.y += bias_half_1.y;
output_half_0.x *= d_gelu(gelu_input_half_0.x);
output_half_0.y *= d_gelu(gelu_input_half_0.y);
output_half_1.x *= d_gelu(gelu_input_half_1.x);
output_half_1.y *= d_gelu(gelu_input_half_1.y);
float2 result;
__half2* result_half2 = reinterpret_cast<__half2*>(&result);
result_half2[0] = __float22half2_rn(output_half_0);
result_half2[1] = __float22half2_rn(output_half_1);
d_output_cast[row * row_stride + i * loop_stride + id] = result;
}
}
#endif
}
template <typename T>
void launch_bias_gelu(const T* input,
const T* bias,
T* output,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
hipLaunchKernelGGL(( fused_bias_gelu), dim3(grid_dims), dim3(block_dims), 0, stream, input, bias, output, intermediate_size);
}
template <typename T>
void launch_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
hipLaunchKernelGGL(( gelu_kernel), dim3(grid_dims), dim3(block_dims), 0, stream, input, output, intermediate_size);
}
template void launch_bias_gelu<float>(const float*, const float*, float*, int, int, hipStream_t);
template void launch_bias_gelu<__half>(const __half*,
const __half*,
__half*,
int,
int,
hipStream_t);
template void launch_gelu<float>(const float*, float*, int, int, hipStream_t);
template void launch_gelu<__half>(const __half*, __half*, int, int, hipStream_t);
template <typename T>
void launch_d_gelu(T* d_output,
const T* input,
const T* bias,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
hipLaunchKernelGGL(( d_gelu_func), dim3(grid_dims), dim3(block_dims), 0, stream, d_output, input, bias, intermediate_size);
}
template void launch_d_gelu<float>(float*, const float*, const float*, int, int, hipStream_t);
template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, hipStream_t);
csrc/transformer/hip/general_kernels.hip
0 → 100644
View file @
eadbbe09
#include "hip/hip_runtime.h"
#include "hip/general_kernels.h"
namespace cg = cooperative_groups;
//template <typename T>
//__global__ void column_sum_reduce(const T* __restrict__ inp,
// T* __restrict__ out,
// int rows,
// int width)
//{
// __shared__ float tile[TILE_DIM][TILE_DIM + 1];
//
// cg::thread_block b = cg::this_thread_block();
// cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
//
// int idx = blockDim.x * blockIdx.x + threadIdx.x;
//
// int y_stride = width * TILE_DIM;
//
// float localSum = 0;
//
// // Loop across matrix height
// if (idx < width) {
// int offset = threadIdx.y * width + idx;
// for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
// localSum += (float)inp[offset];
// offset += y_stride;
// }
// }
//
// tile[threadIdx.x][threadIdx.y] = localSum;
//
// __syncthreads();
//
// // Sum the shared buffer.
// float sum = tile[threadIdx.y][threadIdx.x];
//
//#ifndef __STOCHASTIC_MODE__
// __syncthreads();
//#endif
//
// for (int i = 1; i < TILE_DIM; i <<= 1) sum += g.shfl_down(sum, i);
//
// if (threadIdx.x == 0) {
// int pos = blockIdx.x * TILE_DIM + threadIdx.y;
// if (pos < width) out[pos] = sum;
// }
//}
//template <typename T>
//void launch_fuse_transpose_bias_kernel(const T* inp,
// T* out,
// int rows,
// int cols,
// hipStream_t stream);
//
//template <>
//void launch_fuse_transpose_bias_kernel<float>(const float* inp,
// float* out,
// int rows,
// int cols,
// hipStream_t stream)
//{
// // assert(rows % TILE_DIM == 0);
// // assert(cols % TILE_DIM == 0);
//
// dim3 grid_dim((cols - 1) / TILE_DIM + 1);
// dim3 block_dim(TILE_DIM, TILE_DIM);
//
// hipLaunchKernelGGL(( column_sum_reduce<float>), dim3(grid_dim), dim3(block_dim), 0, stream, inp, out, rows, cols);
//}
//template <>
//void launch_fuse_transpose_bias_kernel<__half>(const __half* inp,
// __half* out,
// int rows,
// int cols,
// hipStream_t stream)
//{
// // assert(rows % TILE_DIM == 0);
// // assert(cols % TILE_DIM == 0);
//
// dim3 grid_dim((cols - 1) / TILE_DIM + 1);
// dim3 block_dim(TILE_DIM, TILE_DIM);
//
// hipLaunchKernelGGL(( column_sum_reduce<__half>), dim3(grid_dim), dim3(block_dim), 0, stream, inp, out, rows, cols);
//}
__global__ void fused_add2_kernel(const int N, float* out, const float* inp1, const float* inp2)
{
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
float4* out_4 = reinterpret_cast<float4*>(out);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 val;
float4 inp1_reg = inp1_4[j];
float4 inp2_reg = inp2_4[j];
val.x = inp1_reg.x + inp2_reg.x;
val.y = inp1_reg.y + inp2_reg.y;
val.z = inp1_reg.z + inp2_reg.z;
val.w = inp1_reg.w + inp2_reg.w;
out_4[j] = val;
}
}
__global__ void fused_add2_kernel(const int N, __half* out, const __half* inp1, const __half* inp2)
{
float2 inp1_4;
float2 inp2_4;
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
CUDA_1D_KERNEL_LOOP(j, N)
{
inp1_4 = inp1_arr[j];
inp2_4 = inp2_arr[j];
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
inp1_h_f_0.x += inp2_h_f_0.x;
inp1_h_f_0.y += inp2_h_f_0.y;
inp1_h_f_1.x += inp2_h_f_1.x;
inp1_h_f_1.y += inp2_h_f_1.y;
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[j] = val_f;
}
}
template <>
void launch_fused_add2<float>(float* out,
const float* inp1,
const float* inp2,
int batch_size,
int seq_length,
int hidden_dim,
hipStream_t& stream)
{
int total_count = batch_size * seq_length * hidden_dim / 4;
dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length);
dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
hipLaunchKernelGGL(( fused_add2_kernel), dim3(grid_dim), dim3(block_dim), 0, stream, total_count, out, inp1, inp2);
}
template <>
void launch_fused_add2<__half>(__half* out,
const __half* inp1,
const __half* inp2,
int batch_size,
int seq_length,
int hidden_dim,
hipStream_t& stream)
{
int total_count = batch_size * seq_length * hidden_dim / 4;
dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length);
dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
hipLaunchKernelGGL(( fused_add2_kernel), dim3(grid_dim), dim3(block_dim), 0, stream, total_count, out, inp1, inp2);
}
__global__ void fused_add3_kernel(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
float4* out_4 = reinterpret_cast<float4*>(out);
float4 val;
float4 inp1_reg = inp1_4[row * row_stride + id];
float4 inp2_reg = inp2_4[row * row_stride + id];
float4 inp3_reg = inp3_4[row * row_stride + id];
val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x;
val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y;
val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z;
val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w;
out_4[row * row_stride + id] = val;
}
__global__ void fused_add3_kernel(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
float2 inp1_4 = inp1_arr[row * row_stride + id];
float2 inp2_4 = inp2_arr[row * row_stride + id];
float2 inp3_4 = inp3_arr[row * row_stride + id];
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
__half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
float2 inp3_h_f_1 = __half22float2(inp3_h[1]);
inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x);
inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y);
inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x);
inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y);
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[row * row_stride + id] = val_f;
}
template <>
void launch_fused_add3<float>(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add3_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
template <>
void launch_fused_add3<__half>(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add3_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
__global__ void fused_add4_kernel(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
const float* inp4,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
const float4* inp4_4 = reinterpret_cast<const float4*>(inp4);
float4* out_4 = reinterpret_cast<float4*>(out);
float4 val;
float4 inp1_reg = inp1_4[row * row_stride + id];
float4 inp2_reg = inp2_4[row * row_stride + id];
float4 inp3_reg = inp3_4[row * row_stride + id];
float4 inp4_reg = inp4_4[row * row_stride + id];
val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x + inp4_reg.x;
val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y + inp4_reg.y;
val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z + inp4_reg.z;
val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w + inp4_reg.w;
out_4[row * row_stride + id] = val;
}
__global__ void fused_add4_kernel(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
const __half* inp4,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
const float2* inp4_arr = reinterpret_cast<const float2*>(inp4);
float2 inp1_4 = inp1_arr[row * row_stride + id];
float2 inp2_4 = inp2_arr[row * row_stride + id];
float2 inp3_4 = inp3_arr[row * row_stride + id];
float2 inp4_4 = inp4_arr[row * row_stride + id];
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
__half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
__half2* inp4_h = reinterpret_cast<__half2*>(&inp4_4);
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
float2 inp3_h_f_1 = __half22float2(inp3_h[1]);
float2 inp4_h_f_0 = __half22float2(inp4_h[0]);
float2 inp4_h_f_1 = __half22float2(inp4_h[1]);
inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x + inp4_h_f_0.x);
inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y + inp4_h_f_0.y);
inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x + inp4_h_f_1.x);
inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y + inp4_h_f_1.y);
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[row * row_stride + id] = val_f;
}
template <>
void launch_fused_add4<float>(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
const float* inp4,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add4_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
template <>
void launch_fused_add4<__half>(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
const __half* inp4,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add4_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
Prev
1
2
3
4
5
6
…
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment