Commit eadbbe09 authored by 401qingkong's avatar 401qingkong
Browse files

push rocm deepspeed v0.3.13

parent ab5534fc
#pragma once
#include <hip/hip_runtime_api.h>
#include <hiprand.h>
#include <memory>
#include <vector>
#include "rocblas.h"
#include "hip/hip_runtime.h"
#include "hip/dropout.h"
#include "hip/feed_forward.h"
#include "hip/gelu.h"
#include "hip/general_kernels.h"
#include "hip/normalize_layer.h"
#include "hip/softmax.h"
#include "hip/strided_batch_gemm.h"
struct BertGemmAlgos {
int m_gemm_qkv_algo;
int m_gemm_inter_algo;
int m_gemm_output_algo;
int m_gemm_batch1_algo;
int m_gemm_batch2_algo;
BertGemmAlgos()
: m_gemm_qkv_algo(-1),
m_gemm_inter_algo(-1),
m_gemm_output_algo(-1),
m_gemm_batch1_algo(-1),
m_gemm_batch2_algo(-1)
{
}
};
template <typename T>
class BertTransformerLayer {
public:
BertTransformerLayer(int layer_id,
int batch_size,
int hidden_size,
int num_heads,
int intermediate_size,
int seq_length,
float attn_dropout_ratio,
float hidden_output_dropout_ratio,
float layer_norm_eps,
bool pre_or_postLayerNorm,
const std::vector<std::array<int, 3>>& gemm_algos,
bool attn_dropout_checkpoint,
bool normalize_invertible,
bool gelu_checkpoint,
bool stochastic_mode);
virtual ~BertTransformerLayer();
void Forward(int bsz,
const T* input_ptr,
const T* input_mask_ptr,
const T* attn_qkvw_ptr,
const T* attn_qkvb_ptr,
const T* attn_ow_ptr,
const T* attn_ob_ptr,
const T* attn_nw_ptr,
const T* attn_nb_ptr,
const T* inter_w_ptr,
const T* inter_b_ptr,
const T* output_w_ptr,
const T* output_b_ptr,
const T* norm_w_ptr,
const T* norm_b_ptr,
T* out_ptr,
T* inp_norm_ptr,
T* q_tf_ptr,
T* k_tf_ptr,
T* v_tf_ptr,
T* softmax_output_ptr,
T* ctx_bufB_ptr,
T* attn_o_inp_ptr,
T* add_res_ptr,
T* ff1_inp_ptr,
T* gelu_inp_ptr,
T* ff2_inp_ptr);
void Backward(int bsz,
const T* grad_output_ptr,
const T* input_ptr,
const T* output_ptr,
const T* inp_norm_ptr,
const T* q_tf_ptr,
const T* k_tf_ptr,
const T* v_tf_ptr,
const T* softmax_output_ptr,
const T* ctx_bufB_ptr,
const T* attn_o_inp_ptr,
const T* add_res_ptr,
const T* ff1_inp_ptr,
const T* gelu_inp_ptr,
const T* ff2_inp_ptr,
const T* input_mask_ptr,
const T* attn_qkvw_ptr,
const T* attn_ow_ptr,
const T* attn_nw_ptr,
const T* attn_nb_ptr,
const T* inter_w_ptr,
const T* inter_b_ptr,
const T* output_w_ptr,
const T* norm_w_ptr,
const T* norm_b_ptr,
T* grad_input_ptr,
T* grad_attn_qkvw_ptr,
T* grad_attn_qkvb_ptr,
T* grad_attn_ow_ptr,
T* grad_attn_ob_ptr,
T* grad_attn_nw_ptr,
T* grad_attn_nb_ptr,
T* grad_inter_w_ptr,
T* grad_inter_b_ptr,
T* grad_output_w_ptr,
T* grad_output_b_ptr,
T* grad_norm_w_ptr,
T* grad_norm_b_ptr);
void SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
uint8_t* attn_output_dropout_mask_ptr,
uint8_t* layer_output_dropout_mask_ptr,
T* layer_norm_var,
T* layer_norm_mean,
T* attn_layer_norm_var,
T* attn_layer_norm_mean);
inline int GetBatchSize() const { return _batch_size; }
inline int GetNumHeads() const { return _heads; }
inline int GetSeqLength() const { return _seq_length; }
inline int GetIntermediateSize() const { return _intermediate_size; }
void SetSeqLength(int seq_len);
inline int GetHiddenSize() const { return _hidden_size; }
void SetTrainingMode(bool training);
inline bool IsTrainingMode() const { return _training; }
inline bool GeluCheckpoint() const { return _gelu_checkpoint; }
private:
void Initialize();
size_t getWorkspaceSize(int maxBatchSize) const;
// Params
int _layer_id;
int _batch_size;
int _hidden_size;
int _heads;
int _size_per_head;
int _intermediate_size;
int _seq_length;
bool _pre_or_postLayerNorm;
rocblas_handle _cublasHandle;
hipStream_t _stream;
// layers
FeedForward<T> _qkv_linear;
FeedForward<T> _attn_out_linear;
Normalize_Layer<T> _attn_layer_norm;
Normalize_Layer<T> _layer_norm;
Normalize_Layer<T>* _last_normalize;
FeedForward<T> _ff1, _ff2;
Softmax<T> _softmax;
Gelu<T> _gelu;
Dropout<T> _attn_prob_dropout;
Dropout<T> _attn_output_dropout;
Dropout<T> _layer_output_dropout;
StridedBatchGemm<T> _attn_scores;
StridedBatchGemm<T> _attn_context;
bool _training;
// Memory saving flags
bool _attn_dropout_checkpoint;
bool _normalize_invertible;
bool _gelu_checkpoint;
// High Performace flags
bool _stochastic_mode;
};
#ifndef __FEEDFORWARD_H__
#define __FEEDFORWARD_H__
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include "hip/custom_hip_layers.h"
template <typename T>
class FeedForward {
public:
struct Config {
int batchSize, outputSize;
int inputSize;
std::array<int, 3> gemm_algos;
Config(int batch, int outputs, int inputs, const std::array<int, 3>& algos)
: batchSize(batch), outputSize(outputs), inputSize(inputs), gemm_algos(algos)
{
}
};
FeedForward(Config config) : config_(config) {}
~FeedForward() {}
void Forward(int bsz,
const T* input_ptr,
const T* weights,
T* out,
rocblas_handle& _cublasHandle)
{
float alpha = T(1.);
float beta = T(0.);
cublas_gemm_ex(_cublasHandle,
rocblas_operation_transpose,
rocblas_operation_none,
config_.outputSize,
bsz,
config_.inputSize,
&alpha,
&beta,
weights,
input_ptr,
out,
//cublasGemmAlgo_t(config_.gemm_algos[0]));
rocblas_gemm_algo(config_.gemm_algos[0]));
}
void Backward(int bsz,
const T* out_grad,
const T* input_ptr,
const T* weights,
T* weights_grad,
T* bias_grad,
rocblas_handle& _cublasHandle,
hipStream_t& stream,
T* inp_grad_out = nullptr,
T* out_grad_trans_out = nullptr)
{
float alpha = (T)1.0, beta = (T)0.0;
cublas_gemm_ex(_cublasHandle,
rocblas_operation_none,
rocblas_operation_transpose,
config_.inputSize,
config_.outputSize,
bsz,
&alpha,
&beta,
input_ptr,
out_grad,
weights_grad,
//cublasGemmAlgo_t(config_.gemm_algos[1]));
rocblas_gemm_algo(config_.gemm_algos[1]));
cublas_gemm_ex(_cublasHandle,
rocblas_operation_none,
rocblas_operation_none,
config_.inputSize,
bsz,
config_.outputSize,
&alpha,
&beta,
weights,
out_grad,
inp_grad_out,
//cublasGemmAlgo_t(config_.gemm_algos[2]));
rocblas_gemm_algo(config_.gemm_algos[2]));
launch_fuse_transpose_bias_kernel<T>(out_grad, bias_grad, bsz, config_.outputSize, stream);
}
private:
Config config_;
};
#endif
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include "hip/custom_hip_layers.h"
template <typename T>
class Gelu {
public:
struct Config {
uint32_t intermediate_size;
Config(uint32_t inter_size) : intermediate_size(inter_size) {}
};
Gelu(const Config& config) : _config(config) {}
virtual ~Gelu() {}
void ForwardWithBiasAdd(int bsz,
const T* input_buf,
const T* bias,
T* output,
hipStream_t stream)
{
launch_bias_gelu<T>(input_buf, bias, output, _config.intermediate_size, bsz, stream);
}
void Backward(int bsz, T* d_output, const T* input_buf, const T* bias, hipStream_t stream)
{
launch_d_gelu<T>(d_output, input_buf, bias, _config.intermediate_size, bsz, stream);
}
private:
Config _config;
};
#pragma once
#include <hip/hip_fp16.h>
#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif
#include <array>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <limits>
#include <memory>
#include "StopWatch.h"
#include "cublas_wrappers.h"
template <typename T>
void check(T result, char const* const func, const char* const file, int const line)
{
if (result) {
std::cout << (std::string("CUDA runtime error: ") + +file + ":" + std::to_string(line) +
" \n");
}
}
#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)
template <typename T>
class GemmTest {
public:
GemmTest(int m, int n, int k, rocblas_operation ta, rocblas_operation tb, rocblas_handle h)
: M(m), N(n), K(k), transa(ta), transb(tb), handle(h)
{
check_cuda_error(hipMalloc((void**)&A, sizeof(T) * M * K));
check_cuda_error(hipMalloc((void**)&B, sizeof(T) * K * N));
check_cuda_error(hipMalloc((void**)&C, sizeof(T) * M * N));
}
~GemmTest()
{
check_cuda_error(hipFree(A));
check_cuda_error(hipFree(B));
check_cuda_error(hipFree(C));
}
std::array<int, 3> TestAlgo(int loops)
{
float alpha = (T)1.0f;
float beta = (T)0.0f;
int algo_fw = Run(loops, [=](int algo) {
cublas_gemm_ex(handle,
rocblas_operation_transpose,
rocblas_operation_none,
N,
M,
K,
&alpha,
&beta,
B,
A,
C,
#ifdef __HIP_PLATFORM_HCC__
static_cast<rocblas_gemm_algo>(algo));
#else
static_cast<cublasGemmAlgo_t>(algo));
#endif
});
int algo_bw1 = Run(loops, [=](int algo) {
cublas_gemm_ex(handle,
rocblas_operation_none,
rocblas_operation_transpose,
K,
N,
M,
&alpha,
&beta,
A,
C,
B,
#ifdef __HIP_PLATFORM_HCC__
static_cast<rocblas_gemm_algo>(algo));
#else
static_cast<cublasGemmAlgo_t>(algo));
#endif
});
int algo_bw2 = Run(loops, [=](int algo) {
cublas_gemm_ex(handle,
rocblas_operation_none,
rocblas_operation_none,
K,
M,
N,
&alpha,
&beta,
B,
C,
A,
#ifdef __HIP_PLATFORM_HCC__
static_cast<rocblas_gemm_algo>(algo));
#else
static_cast<cublasGemmAlgo_t>(algo));
#endif
});
return std::array<int, 3>({algo_fw, algo_bw1, algo_bw2});
}
template <typename Func>
int Run(int loops, Func f)
{
//float fast_latency = std::numeric_limits<float>::max();
float fast_latency = (std::numeric_limits<float>::max)();
int fast_algo = 0;
#ifdef __HIP_PLATFORM_HCC__
for (int algo = (int)rocblas_gemm_algo_standard;
algo <= (int)rocblas_gemm_algo_standard;
#else
for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
#endif
algo++) {
int warm_up = 5;
for (int i = 0; i < warm_up; ++i) f(algo);
hipDeviceSynchronize();
Stopwatch timer;
timer.Restart();
for (int i = 0; i < loops; ++i) f(algo);
hipDeviceSynchronize();
timer.Stop();
float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops;
printf("algo-%d: %.3fms\n", algo, avg_latency);
if (avg_latency < fast_latency) {
fast_latency = avg_latency;
fast_algo = algo;
}
}
printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency);
return fast_algo;
}
private:
int M, N, K;
rocblas_handle handle;
rocblas_operation transa, transb;
T *A, *B, *C;
};
template <typename T>
class StridedGemmTest {
public:
StridedGemmTest(int b,
int m,
int n,
int k,
rocblas_operation ta,
rocblas_operation tb,
rocblas_handle h)
: bsz(b), M(m), N(n), K(k), transa(ta), transb(tb), handle(h)
{
check_cuda_error(hipMalloc((void**)&A, sizeof(T) * M * K * bsz));
check_cuda_error(hipMalloc((void**)&B, sizeof(T) * K * N * bsz));
check_cuda_error(hipMalloc((void**)&C, sizeof(T) * M * N * bsz));
}
~StridedGemmTest()
{
check_cuda_error(hipFree(A));
check_cuda_error(hipFree(B));
check_cuda_error(hipFree(C));
}
std::array<int, 3> TestAlgo(int loops)
{
float alpha = (T)1.0f;
float beta = (T)0.0f;
int algo_fw = Run(loops, [=](int algo) {
int stride_a = M * K;
int stride_b = N * K;
int stride_c = M * N;
cublas_strided_batched_gemm(handle,
M,
N,
K,
&alpha,
&beta,
A,
B,
C,
transa,
transb,
stride_a,
stride_b,
stride_c,
bsz,
#ifdef __HIP_PLATFORM_HCC__
static_cast<rocblas_gemm_algo>(algo));
#else
static_cast<cublasGemmAlgo_t>(algo));
#endif
});
int algo_bw1 = Run(loops, [=](int algo) {
int mb = (transa == rocblas_operation_transpose ? K : M);
int kb = (transa == rocblas_operation_transpose ? M : K);
int stride_a = mb * N;
int stride_b = N * kb;
int stride_c = M * K;
// B need to transpose.
rocblas_operation op_b = (transb == rocblas_operation_transpose ? rocblas_operation_none : rocblas_operation_transpose);
// Calculate d_A.
cublas_strided_batched_gemm(handle,
mb,
kb,
N,
&alpha,
&beta,
(transa == rocblas_operation_transpose ? B : C),
(transa == rocblas_operation_transpose ? C : B),
A,
rocblas_operation_none,
op_b,
stride_a,
stride_b,
stride_c,
bsz,
#ifdef __HIP_PLATFORM_HCC__
static_cast<rocblas_gemm_algo>(algo));
#else
static_cast<cublasGemmAlgo_t>(algo));
#endif
});
int algo_bw2 = Run(loops, [=](int algo) {
// A need to transpose.
rocblas_operation op_a = (transa == rocblas_operation_transpose ? rocblas_operation_none : rocblas_operation_transpose);
int stride_a = M * K;
int stride_b = M * N;
int stride_c = N * K;
// Calculate d_B.
cublas_strided_batched_gemm(handle,
K,
N,
M,
&alpha,
&beta,
A,
C,
B,
op_a,
rocblas_operation_none,
stride_a,
stride_b,
stride_c,
bsz,
#ifdef __HIP_PLATFORM_HCC__
static_cast<rocblas_gemm_algo>(algo));
#else
static_cast<cublasGemmAlgo_t>(algo));
#endif
});
return std::array<int, 3>({algo_fw, algo_bw1, algo_bw2});
}
template <typename Func>
int Run(int loops, Func f)
{
//float fast_latency = std::numeric_limits<float>::max();
float fast_latency = (std::numeric_limits<float>::max)();
int fast_algo = 0;
#ifdef __HIP_PLATFORM_HCC__
for (int algo = (int)rocblas_gemm_algo_standard;
algo <= (int)rocblas_gemm_algo_standard;
#else
for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
#endif
algo++) {
int warm_up = 5;
for (int i = 0; i < warm_up; ++i) f(algo);
hipDeviceSynchronize();
Stopwatch timer;
timer.Restart();
for (int i = 0; i < loops; ++i) f(algo);
hipDeviceSynchronize();
timer.Stop();
float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops;
printf("algo-%d: %.3fms\n", algo, avg_latency);
if (avg_latency < fast_latency) {
fast_latency = avg_latency;
fast_algo = algo;
}
}
printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency);
return fast_algo;
}
private:
int bsz, M, N, K;
rocblas_handle handle;
rocblas_operation transa, transb;
T *A, *B, *C;
};
#pragma once
#include <hip/hip_fp16.h>
#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif
#include <array>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <limits>
#include <memory>
#include "StopWatch.h"
#include "hip/cublas_wrappers.h"
template <typename T>
void check(T result, char const* const func, const char* const file, int const line)
{
if (result) {
std::cout << (std::string("CUDA runtime error: ") + +file + ":" + std::to_string(line) +
" \n");
}
}
#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)
template <typename T>
class GemmTest {
public:
GemmTest(int m, int n, int k, rocblas_operation ta, rocblas_operation tb, rocblas_handle h)
: M(m), N(n), K(k), transa(ta), transb(tb), handle(h)
{
check_cuda_error(hipMalloc((void**)&A, sizeof(T) * M * K));
check_cuda_error(hipMalloc((void**)&B, sizeof(T) * K * N));
check_cuda_error(hipMalloc((void**)&C, sizeof(T) * M * N));
}
~GemmTest()
{
check_cuda_error(hipFree(A));
check_cuda_error(hipFree(B));
check_cuda_error(hipFree(C));
}
std::array<int, 3> TestAlgo(int loops)
{
float alpha = (T)1.0f;
float beta = (T)0.0f;
int algo_fw = Run(loops, [=](int algo) {
cublas_gemm_ex(handle,
rocblas_operation_transpose,
rocblas_operation_none,
N,
M,
K,
&alpha,
&beta,
B,
A,
C,
static_cast<cublasGemmAlgo_t>(algo));
});
int algo_bw1 = Run(loops, [=](int algo) {
cublas_gemm_ex(handle,
rocblas_operation_none,
rocblas_operation_transpose,
K,
N,
M,
&alpha,
&beta,
A,
C,
B,
static_cast<cublasGemmAlgo_t>(algo));
});
int algo_bw2 = Run(loops, [=](int algo) {
cublas_gemm_ex(handle,
rocblas_operation_none,
rocblas_operation_none,
K,
M,
N,
&alpha,
&beta,
B,
C,
A,
static_cast<cublasGemmAlgo_t>(algo));
});
return std::array<int, 3>({algo_fw, algo_bw1, algo_bw2});
}
template <typename Func>
int Run(int loops, Func f)
{
float fast_latency = (std::numeric_limits<float>::max)();
int fast_algo = 0;
for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
algo++) {
int warm_up = 5;
for (int i = 0; i < warm_up; ++i) f(algo);
hipDeviceSynchronize();
Stopwatch timer;
timer.Restart();
for (int i = 0; i < loops; ++i) f(algo);
hipDeviceSynchronize();
timer.Stop();
float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops;
printf("algo-%d: %.3fms\n", algo, avg_latency);
if (avg_latency < fast_latency) {
fast_latency = avg_latency;
fast_algo = algo;
}
}
printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency);
return fast_algo;
}
private:
int M, N, K;
rocblas_handle handle;
rocblas_operation transa, transb;
T *A, *B, *C;
};
template <typename T>
class StridedGemmTest {
public:
StridedGemmTest(int b,
int m,
int n,
int k,
rocblas_operation ta,
rocblas_operation tb,
rocblas_handle h)
: bsz(b), M(m), N(n), K(k), transa(ta), transb(tb), handle(h)
{
check_cuda_error(hipMalloc((void**)&A, sizeof(T) * M * K * bsz));
check_cuda_error(hipMalloc((void**)&B, sizeof(T) * K * N * bsz));
check_cuda_error(hipMalloc((void**)&C, sizeof(T) * M * N * bsz));
}
~StridedGemmTest()
{
check_cuda_error(hipFree(A));
check_cuda_error(hipFree(B));
check_cuda_error(hipFree(C));
}
std::array<int, 3> TestAlgo(int loops)
{
float alpha = (T)1.0f;
float beta = (T)0.0f;
int algo_fw = Run(loops, [=](int algo) {
int stride_a = M * K;
int stride_b = N * K;
int stride_c = M * N;
cublas_strided_batched_gemm(handle,
M,
N,
K,
&alpha,
&beta,
A,
B,
C,
transa,
transb,
stride_a,
stride_b,
stride_c,
bsz,
static_cast<cublasGemmAlgo_t>(algo));
});
int algo_bw1 = Run(loops, [=](int algo) {
int mb = (transa == rocblas_operation_transpose ? K : M);
int kb = (transa == rocblas_operation_transpose ? M : K);
int stride_a = mb * N;
int stride_b = N * kb;
int stride_c = M * K;
// B need to transpose.
rocblas_operation op_b = (transb == rocblas_operation_transpose ? rocblas_operation_none : rocblas_operation_transpose);
// Calculate d_A.
cublas_strided_batched_gemm(handle,
mb,
kb,
N,
&alpha,
&beta,
(transa == rocblas_operation_transpose ? B : C),
(transa == rocblas_operation_transpose ? C : B),
A,
rocblas_operation_none,
op_b,
stride_a,
stride_b,
stride_c,
bsz,
static_cast<cublasGemmAlgo_t>(algo));
});
int algo_bw2 = Run(loops, [=](int algo) {
// A need to transpose.
rocblas_operation op_a = (transa == rocblas_operation_transpose ? rocblas_operation_none : rocblas_operation_transpose);
int stride_a = M * K;
int stride_b = M * N;
int stride_c = N * K;
// Calculate d_B.
cublas_strided_batched_gemm(handle,
K,
N,
M,
&alpha,
&beta,
A,
C,
B,
op_a,
rocblas_operation_none,
stride_a,
stride_b,
stride_c,
bsz,
static_cast<cublasGemmAlgo_t>(algo));
});
return std::array<int, 3>({algo_fw, algo_bw1, algo_bw2});
}
template <typename Func>
int Run(int loops, Func f)
{
float fast_latency = (std::numeric_limits<float>::max)();
int fast_algo = 0;
for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
algo++) {
int warm_up = 5;
for (int i = 0; i < warm_up; ++i) f(algo);
hipDeviceSynchronize();
Stopwatch timer;
timer.Restart();
for (int i = 0; i < loops; ++i) f(algo);
hipDeviceSynchronize();
timer.Stop();
float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops;
printf("algo-%d: %.3fms\n", algo, avg_latency);
if (avg_latency < fast_latency) {
fast_latency = avg_latency;
fast_algo = algo;
}
}
printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency);
return fast_algo;
}
private:
int bsz, M, N, K;
rocblas_handle handle;
rocblas_operation transa, transb;
T *A, *B, *C;
};
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include <stdlib.h>
#include <hip/hip_cooperative_groups.h>
#include <hiprand_kernel.h>
#include "hip/context.h"
#include "hip/cublas_wrappers.h"
#define THREADS 256
#define TILE_DIM 32
#define minus_infinity -1 * std::numeric_limits<float>::infinity()
#define FINAL_MASK 0xffffffff
template <typename T>
void launch_fused_add2(T* out,
const T* inp1,
const T* inp2,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream);
template <typename T>
void launch_fused_add4(T* out,
const T* inp1,
const T* inp2,
const T* inp3,
const T* inp4,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream);
template <typename T>
void launch_fused_add3(T* out,
const T* inp1,
const T* inp2,
const T* inp3,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream);
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include <fstream>
#include "hip/custom_hip_layers.h"
using namespace std;
template <typename T>
class Normalize_Layer {
public:
struct Config {
uint32_t batchSize;
uint32_t seqLength;
uint32_t hiddenDim;
float epsilon;
bool training;
bool useMean;
Config(uint32_t batch,
uint32_t seq,
uint32_t h,
float epsilon = 1e-12,
bool training = true,
bool useMean = true)
: batchSize(batch),
seqLength(seq),
hiddenDim(h),
epsilon(epsilon),
training(training),
useMean(useMean)
{
}
};
Normalize_Layer(Config config)
: config_(config), vars(nullptr), means(nullptr), vals_hat(nullptr)
{
}
~Normalize_Layer() {}
void ForwardCheckpoint(int bsz, // batch * seq
T* vals,
const T* residual,
const T* gamma,
const T* betta,
hipStream_t& stream,
bool preLayerNorm = false)
{
launch_bias_residual_layer_norm(vals,
residual,
gamma,
betta,
config_.epsilon,
bsz,
config_.hiddenDim,
stream,
preLayerNorm,
config_.training,
vars,
means);
}
void Forward(int bsz,
T* vals,
const T* residual,
const T* gamma,
const T* betta,
hipStream_t& stream,
bool preLayerNorm = false)
{
launch_bias_residual_layer_norm(vals,
residual,
gamma,
betta,
config_.epsilon,
bsz,
config_.hiddenDim,
stream,
preLayerNorm,
config_.training,
vars);
}
void Backward(int bsz,
const T* out_grad,
const T* gamma,
T* gamma_grad,
T* betta_grad,
hipStream_t stream[2],
T* inp_grad_out,
const T* norm_in = nullptr)
{
launch_layerNorm_backward(out_grad,
norm_in,
vars,
means,
gamma,
gamma_grad,
betta_grad,
inp_grad_out,
bsz,
config_.hiddenDim,
stream);
}
void Backward(int bsz,
const T* out_grad,
const T* gamma,
const T* betta,
T* gamma_grad,
T* betta_grad,
hipStream_t stream[2],
T* inp_grad_out,
const T* norm_out)
{
launch_layerNorm_backward(out_grad,
norm_out,
vars,
gamma,
gamma_grad,
betta_grad,
inp_grad_out,
bsz,
config_.hiddenDim,
stream,
!config_.useMean,
betta);
}
void BackwardFusedAdd(int bsz,
const T* out_grad1,
const T* out_grad2,
const T* gamma,
T* gamma_grad,
T* betta_grad,
hipStream_t stream[2],
T* inp_grad_out,
const T* norm_in = nullptr)
{
launch_layerNorm_backward_fused_add(out_grad1,
out_grad2,
norm_in,
vars,
means,
gamma,
gamma_grad,
betta_grad,
inp_grad_out,
bsz,
config_.hiddenDim,
stream);
}
void BackwardFusedAdd(int bsz,
const T* out_grad1,
const T* out_grad2,
const T* gamma,
const T* betta,
T* gamma_grad,
T* betta_grad,
hipStream_t stream[2],
T* inp_grad_out,
const T* norm_out)
{
launch_layerNorm_backward_fused_add(out_grad1,
out_grad2,
norm_out,
vars,
gamma,
gamma_grad,
betta_grad,
inp_grad_out,
bsz,
config_.hiddenDim,
stream,
!config_.useMean,
betta);
}
inline bool UseMean() const { return config_.useMean; }
inline void SetVar(T* variance)
{
if (!variance) { throw std::runtime_error("Normalize variance is null."); }
vars = variance;
}
inline void SetMean(T* mean)
{
if (!mean) { throw std::runtime_error("Normalize mean is null."); }
means = mean;
}
private:
Config config_;
T* vars;
T* means;
T* vals_hat;
};
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include "hip/custom_hip_layers.h"
#include <fstream>
using namespace std;
template <typename T>
class Softmax {
public:
struct Config {
size_t batchSize;
size_t heads;
size_t seq_length;
size_t prob_depth;
float temprature;
bool mem_alloc;
Config(size_t batch, size_t h, size_t seq, int prob_size = 0, bool mem_alloc = false)
: batchSize(batch),
heads(h),
seq_length(seq),
prob_depth(prob_size),
temprature(1.0),
mem_alloc(mem_alloc)
{
}
};
Softmax(Config config) : config_(config) {}
~Softmax() {}
void Forward(int bsz, T* vals, const T* attn_mask, hipStream_t& stream)
{
launch_attn_softmax<T>(vals, attn_mask, bsz, config_.heads, config_.seq_length, stream);
}
void Backward(int bsz, T* out_grad, const T* soft_out, hipStream_t stream)
{
launch_attn_softmax_backward_v2<T>(
out_grad, soft_out, bsz, config_.heads, config_.seq_length, stream);
}
inline size_t GetProbDepth() const { return config_.prob_depth; }
inline size_t GetBatchSize() const { return config_.batchSize; }
inline size_t GetNumHeads() const { return config_.heads; }
inline size_t GetSeqLength() const { return config_.seq_length; }
inline void SetSeqLength(size_t seq_len) { config_.seq_length = seq_len; }
private:
Config config_;
};
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include "hip/context.h"
template <typename T>
class StridedBatchGemm {
public:
struct Config {
int batch_size;
int m;
int n;
int k;
float alpha;
float beta;
rocblas_operation op_A;
rocblas_operation op_B;
std::array<int, 3> gemm_algos;
Config(int batch,
int mm,
int nn,
int kk,
float param_alpha,
float param_beta,
rocblas_operation opA,
rocblas_operation opB,
const std::array<int, 3>& algos)
: batch_size(batch),
m(mm),
n(nn),
k(kk),
alpha(param_alpha),
beta(param_beta),
op_A(opA),
op_B(opB),
gemm_algos(algos)
{
}
void SetConfig(int mm, int nn, int kk)
{
m = mm;
n = nn;
k = kk;
}
};
StridedBatchGemm(const Config& config) : _config(config) {}
virtual ~StridedBatchGemm() {}
void Forward(int bsz, T* output, const T* _buffer_a, const T* _buffer_b, rocblas_handle handle)
{
int stride_a = _config.m * _config.k;
int stride_b = _config.n * _config.k;
int stride_c = _config.m * _config.n;
cublas_strided_batched_gemm(handle,
//rocblas_sgemm_strided_batched(handle,
_config.m,
_config.n,
_config.k,
&_config.alpha,
&_config.beta,
_buffer_a,
_buffer_b,
output,
_config.op_A,
_config.op_B,
stride_a,
stride_b,
stride_c,
bsz,
rocblas_gemm_algo(_config.gemm_algos[0]));
//rocblas_sgemm_strided_batched(handle,
}
void ForwardPlusSave(T* output, const T* _buffer_a, const T* _buffer_b, rocblas_handle handle)
{
int stride_a = _config.m * _config.k;
int stride_b = _config.n * _config.k;
int stride_c = _config.m * _config.n;
cublas_strided_batched_gemm(handle,
_config.m,
_config.n,
_config.k,
&_config.alpha,
&_config.beta,
_buffer_a,
_buffer_b,
output,
_config.op_A,
_config.op_B,
stride_a,
stride_b,
stride_c,
_config.batch_size,
//cublasGemmAlgo_t(_config.gemm_algos[0]));
rocblas_gemm_algo(_config.gemm_algos[0]));
k_buf = _buffer_a;
q_buf = _buffer_b;
}
void Backward(int bsz,
const T* d_output,
const T* _buffer_a,
const T* _buffer_b,
rocblas_handle handle,
T* inpGradA = nullptr,
T* inpGradB = nullptr)
{
int mb = (_config.op_A == rocblas_operation_transpose ? _config.k : _config.m);
int kb = (_config.op_A == rocblas_operation_transpose ? _config.m : _config.k);
int stride_a = mb * _config.n;
int stride_b = _config.n * kb;
int stride_c = _config.m * _config.k;
// B need to transpose.
rocblas_operation op_b = (_config.op_B == rocblas_operation_transpose ? rocblas_operation_none : rocblas_operation_transpose);
// Calculate d_A.
cublas_strided_batched_gemm(handle,
//rocblas_sgemm_strided_batched(handle,
mb,
kb,
_config.n,
&_config.alpha,
&_config.beta,
(_config.op_A == rocblas_operation_transpose ? _buffer_b : d_output),
(_config.op_A == rocblas_operation_transpose ? d_output : _buffer_b),
inpGradA,
rocblas_operation_none,
op_b,
stride_a,
stride_b,
stride_c,
bsz,
//cublasGemmAlgo_t(_config.gemm_algos[1]));
rocblas_gemm_algo(_config.gemm_algos[1]));
// A need to transpose.
rocblas_operation op_a = (_config.op_A == rocblas_operation_transpose ? rocblas_operation_none : rocblas_operation_transpose);
stride_a = _config.m * _config.k;
stride_b = _config.m * _config.n;
stride_c = _config.n * _config.k;
// Calculate d_B.
cublas_strided_batched_gemm(handle,
//rocblas_sgemm_strided_batched(handle,
_config.k,
_config.n,
_config.m,
&_config.alpha,
&_config.beta,
_buffer_a,
d_output,
inpGradB,
op_a,
rocblas_operation_none,
stride_a,
stride_b,
stride_c,
bsz,
//cublasGemmAlgo_t(_config.gemm_algos[2]));
rocblas_gemm_algo(_config.gemm_algos[2]));
}
inline int GetN() const { return _config.k; }
inline const T* GetBufferA() const { return k_buf; }
inline const T* GetBufferB() const { return q_buf; }
inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }
private:
Config _config;
const T* q_buf;
const T* k_buf;
};
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include "hip/context.h"
template <typename T>
class StridedBatchGemm {
public:
struct Config {
int batch_size;
int m;
int n;
int k;
float alpha;
float beta;
rocblas_operation op_A;
rocblas_operation op_B;
std::array<int, 3> gemm_algos;
Config(int batch,
int mm,
int nn,
int kk,
float param_alpha,
float param_beta,
rocblas_operation opA,
rocblas_operation opB,
const std::array<int, 3>& algos)
: batch_size(batch),
m(mm),
n(nn),
k(kk),
alpha(param_alpha),
beta(param_beta),
op_A(opA),
op_B(opB),
gemm_algos(algos)
{
}
void SetConfig(int mm, int nn, int kk)
{
m = mm;
n = nn;
k = kk;
}
};
StridedBatchGemm(const Config& config) : _config(config) {}
virtual ~StridedBatchGemm() {}
void Forward(int bsz, T* output, const T* _buffer_a, const T* _buffer_b, rocblas_handle handle)
{
int stride_a = _config.m * _config.k;
int stride_b = _config.n * _config.k;
int stride_c = _config.m * _config.n;
cublas_strided_batched_gemm(handle,
_config.m,
_config.n,
_config.k,
&_config.alpha,
&_config.beta,
_buffer_a,
_buffer_b,
output,
_config.op_A,
_config.op_B,
stride_a,
stride_b,
stride_c,
bsz,
cublasGemmAlgo_t(_config.gemm_algos[0]));
}
void ForwardPlusSave(T* output, const T* _buffer_a, const T* _buffer_b, rocblas_handle handle)
{
int stride_a = _config.m * _config.k;
int stride_b = _config.n * _config.k;
int stride_c = _config.m * _config.n;
cublas_strided_batched_gemm(handle,
_config.m,
_config.n,
_config.k,
&_config.alpha,
&_config.beta,
_buffer_a,
_buffer_b,
output,
_config.op_A,
_config.op_B,
stride_a,
stride_b,
stride_c,
_config.batch_size,
cublasGemmAlgo_t(_config.gemm_algos[0]));
k_buf = _buffer_a;
q_buf = _buffer_b;
}
void Backward(int bsz,
const T* d_output,
const T* _buffer_a,
const T* _buffer_b,
rocblas_handle handle,
T* inpGradA = nullptr,
T* inpGradB = nullptr)
{
int mb = (_config.op_A == rocblas_operation_transpose ? _config.k : _config.m);
int kb = (_config.op_A == rocblas_operation_transpose ? _config.m : _config.k);
int stride_a = mb * _config.n;
int stride_b = _config.n * kb;
int stride_c = _config.m * _config.k;
// B need to transpose.
rocblas_operation op_b = (_config.op_B == rocblas_operation_transpose ? rocblas_operation_none : rocblas_operation_transpose);
// Calculate d_A.
cublas_strided_batched_gemm(handle,
mb,
kb,
_config.n,
&_config.alpha,
&_config.beta,
(_config.op_A == rocblas_operation_transpose ? _buffer_b : d_output),
(_config.op_A == rocblas_operation_transpose ? d_output : _buffer_b),
inpGradA,
rocblas_operation_none,
op_b,
stride_a,
stride_b,
stride_c,
bsz,
cublasGemmAlgo_t(_config.gemm_algos[1]));
// A need to transpose.
rocblas_operation op_a = (_config.op_A == rocblas_operation_transpose ? rocblas_operation_none : rocblas_operation_transpose);
stride_a = _config.m * _config.k;
stride_b = _config.m * _config.n;
stride_c = _config.n * _config.k;
// Calculate d_B.
cublas_strided_batched_gemm(handle,
_config.k,
_config.n,
_config.m,
&_config.alpha,
&_config.beta,
_buffer_a,
d_output,
inpGradB,
op_a,
rocblas_operation_none,
stride_a,
stride_b,
stride_c,
bsz,
cublasGemmAlgo_t(_config.gemm_algos[2]));
}
inline int GetN() const { return _config.k; }
inline const T* GetBufferA() const { return k_buf; }
inline const T* GetBufferB() const { return q_buf; }
inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }
private:
Config _config;
const T* q_buf;
const T* k_buf;
};
#include "hip/hip_runtime.h"
/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */
#include <ATen/ATen.h>
// Forward/backward compatiblity hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template <typename T>
__device__ __forceinline__ T
reduce_block_into_lanes(T* x,
T val,
int lanes = 1,
bool share_result = false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = x[tid] + x[tid + i];
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = x[tid] + x[tid + 32];
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
final = final + __shfl_down_sync(0xffffffff, final, i);
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
/* Copyright 2019 The Microsoft DeepSpeed Team */
#include <torch/extension.h>
// CUDA forward declaration
void fused_lamb_cuda(at::Tensor& p,
at::Tensor& p_copy,
at::Tensor& m,
at::Tensor& v,
at::Tensor& g,
float lr,
float beta1,
float beta2,
float max_coeff,
float min_coeff,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay,
at::Tensor& w_l2_i,
at::Tensor& u_l2_i,
at::Tensor& lamb_coeff_val);
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
// C++ interface
at::Tensor lamb(at::Tensor& p,
at::Tensor& p_copy,
at::Tensor& m,
at::Tensor& v,
at::Tensor& g,
float lr,
float beta1,
float beta2,
float max_coeff,
float min_coeff,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay)
{
CHECK_INPUT(p);
if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
CHECK_INPUT(m);
CHECK_INPUT(v);
CHECK_INPUT(g);
int64_t num_elem = p.numel();
AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
AT_ASSERTM(
p_copy.numel() == num_elem || p_copy.numel() == 0,
"number of elements in p_copy and p tensors should be equal, or p_copy should be empty");
// intermediate for weight L2 reduction
// make sure that the threads per block is at least 512 during the kernel launch otherwise the
// behavious is unexpected
at::Tensor w_l2_i = at::empty(
{512},
p.options().dtype(p.type().scalarType() == at::ScalarType::Half ? at::ScalarType::Float
: p.type().scalarType()));
// intermediate for update L2 reduction
// make sure that the threads per block is at least 512 during the kernel launch otherwise the
// behavious is unexpected
at::Tensor u_l2_i = at::empty(
{512},
p.options().dtype(p.type().scalarType() == at::ScalarType::Half ? at::ScalarType::Float
: p.type().scalarType()));
at::Tensor lamb_coeff_val = at::empty(
{1},
p.options().dtype(p.type().scalarType() == at::ScalarType::Half ? at::ScalarType::Float
: p.type().scalarType()));
fused_lamb_cuda(p,
p_copy,
m,
v,
g,
lr,
beta1,
beta2,
max_coeff,
min_coeff,
eps,
grad_scale,
step,
mode,
bias_correction,
decay,
w_l2_i,
u_l2_i,
lamb_coeff_val);
return lamb_coeff_val;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("lamb", &lamb, "Adam optimized CUDA implementation with LAMB.");
}
/* Copyright 2019 The Microsoft DeepSpeed Team */
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/ATen.h"
#include "ATen/TensorUtils.h"
#include "ATen/hip/HIPContext.h"
#include "ATen/hip/detail/IndexUtils.cuh"
//#include "ATen/Type.h"
#include <THH/THHGeneral.h>
#include "ATen/AccumulateType.h"
#include <iostream>
//#include <helper_functions.h>
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
#include <hip/hip_cooperative_groups.h>
#else
#include <cooperative_groups.h>
#endif
#include <hip/hip_runtime_api.h>
#include <stdio.h>
namespace cg = cooperative_groups;
// Utility class used to avoid linker errors with extern
// unsized shared memory arrays with templated type
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
template <typename T>
struct SharedMemory {
// Ensure that we won't compile any un-specialized types
__device__ inline operator T*()
{
extern __device__ void error(void);
error();
return NULL;
}
};
template <>
struct SharedMemory<float> {
__device__ inline operator float*()
{
HIP_DYNAMIC_SHARED( float, s_float)
return s_float;
}
};
template <>
struct SharedMemory<double> {
__device__ inline operator double*()
{
HIP_DYNAMIC_SHARED( double, s_double)
return s_double;
}
};
} // namespace
#include "hip/type_shim.h"
//#include "type_shim.h"
typedef enum {
ADAM_MODE_0 = 0, // eps under square root
ADAM_MODE_1 = 1 // eps outside square root
} adamMode_t;
// s_a and s_b are in shared memory
// g_a and g_b are in shared memory
template <typename T, int blockSize>
__device__ void reduce_block_in_shared_memory(T* s_a, T* s_b, T* g_a, T* g_b)
{
// Handle to thread block group
cg::thread_block cta = cg::this_thread_block();
// perform block reduction in shared memory,
unsigned int tid = cta.thread_rank();
T a_sum = s_a[tid];
T b_sum = s_b[tid];
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
cta.sync();
#else
cg::sync(cta);
#endif
// do reduction in shared mem
if ((blockSize >= 512) && (tid < 256)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 256];
s_b[tid] = b_sum = b_sum + s_b[tid + 256];
}
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
cta.sync();
#else
cg::sync(cta);
#endif
if ((blockSize >= 256) && (tid < 128)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 128];
s_b[tid] = b_sum = b_sum + s_b[tid + 128];
}
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
cta.sync();
#else
cg::sync(cta);
#endif
if ((blockSize >= 128) && (tid < 64)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 64];
s_b[tid] = b_sum = b_sum + s_b[tid + 64];
}
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
cta.sync();
#else
cg::sync(cta);
#endif
#if (__CUDA_ARCH__ >= 300)
if (tid < 32) {
cg::coalesced_group active = cg::coalesced_threads();
// Fetch final intermediate sum from 2nd warp
if (blockSize >= 64) {
a_sum = a_sum + s_a[tid + 32];
b_sum = b_sum + s_b[tid + 32];
}
// Reduce final warp using shuffle
for (int offset = warpSize / 2; offset > 0; offset /= 2) {
a_sum += active.shfl_down(a_sum, offset);
b_sum += active.shfl_down(b_sum, offset);
}
}
#else
if ((blockSize >= 64) && (tid < 32)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 32];
s_b[tid] = b_sum = b_sum + s_b[tid + 32];
}
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
cta.sync();
#else
cg::sync(cta);
#endif
if ((blockSize >= 32) && (tid < 16)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 16];
s_b[tid] = b_sum = b_sum + s_b[tid + 16];
}
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
cta.sync();
#else
cg::sync(cta);
#endif
if ((blockSize >= 16) && (tid < 8)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 8];
s_b[tid] = b_sum = b_sum + s_b[tid + 8];
}
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
cta.sync();
#else
cg::sync(cta);
#endif
if ((blockSize >= 8) && (tid < 4)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 4];
s_b[tid] = b_sum = b_sum + s_b[tid + 4];
}
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
cta.sync();
#else
cg::sync(cta);
#endif
if ((blockSize >= 4) && (tid < 2)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 2];
s_b[tid] = b_sum = b_sum + s_b[tid + 2];
}
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
cta.sync();
#else
cg::sync(cta);
#endif
if ((blockSize >= 2) && (tid < 1)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 1];
s_b[tid] = b_sum = b_sum + s_b[tid + 1];
}
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
cta.sync();
#else
cg::sync(cta);
#endif
#endif
// write result for this block to global mem
if (tid == 0) {
g_a[blockIdx.x] = (T)a_sum;
g_b[blockIdx.x] = (T)b_sum;
}
}
template <typename T, int blockSize>
__device__ void reduce_two_vectors_in_register(T a, T b, T* g_a, T* g_b)
{
const int threadIdInBlock = cg::this_thread_block().thread_rank();
T* s_a = SharedMemory<T>();
T* s_b = SharedMemory<T>() + cg::this_thread_block().size();
s_a[threadIdInBlock] = a;
s_b[threadIdInBlock] = b;
reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part1(
T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T* __restrict__ g,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay,
T* __restrict__ w_l2_i,
T* __restrict__ u_l2_i)
{
// Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = cg::this_thread_block().thread_rank();
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;
T reg_w = 0;
T reg_u = 0;
for (int j = i; j < tsize; j += totThreads) {
T scaled_grad = g[j] / grad_scale;
T pj = p[j];
m[j] = b1 * m[j] + (1 - b1) * scaled_grad;
v[j] = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(v[j] + eps);
else // Mode 1
denom = sqrtf(v[j]) + eps;
T update = (m[j] / denom) + (decay * p[j]);
reg_u += update * update;
reg_w += pj * pj;
}
reduce_two_vectors_in_register<T, blockSize>(reg_w, reg_u, w_l2_i, u_l2_i);
}
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part2(const size_t tsize, T* __restrict__ g_a, T* __restrict__ g_b)
{
T* s_a = SharedMemory<T>();
T* s_b = SharedMemory<T>() + cg::this_thread_block().size();
const int threadIdInBlock = cg::this_thread_block().thread_rank();
s_a[threadIdInBlock] = g_a[threadIdInBlock];
s_b[threadIdInBlock] = g_b[threadIdInBlock];
if (threadIdInBlock >= tsize) {
s_a[threadIdInBlock] = 0.0;
s_b[threadIdInBlock] = 0.0;
}
reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}
template <typename T, typename GRAD_T>
__global__ void lamb_cuda_kernel_part3(
T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T* __restrict__ g,
const float b1,
const float b2,
const float max_coeff,
const float min_coeff,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay,
T* __restrict__ w_l2_i,
T* __restrict__ u_l2_i,
T* __restrict__ lamb_coeff_val)
{
// Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = cg::this_thread_block().thread_rank();
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;
T reg_w = sqrtf(w_l2_i[0]);
T reg_u = sqrtf(u_l2_i[0]);
float lamb_coeff = 1.0;
if (reg_w != 0 and reg_u != 0) {
lamb_coeff = reg_w / reg_u;
if (lamb_coeff > max_coeff) { lamb_coeff = max_coeff; }
if (lamb_coeff < min_coeff) { lamb_coeff = min_coeff; }
}
if (blockId == 0 and threadIdInBlock == 0) {
lamb_coeff_val[0] = lamb_coeff;
// printf("Cuda Lamb Coeff is %.6f \n",lamb_coeff);
}
for (int j = i; j < tsize; j += totThreads) {
T pj = (float)p[j];
T mj = m[j];
T vj = v[j];
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(vj + eps);
else // Mode 1
denom = sqrtf(vj) + eps;
T update = (mj / denom) + (decay * pj);
pj = pj - (step_size * lamb_coeff * update);
p[j] = pj;
if (p_copy != NULL) p_copy[j] = (GRAD_T)pj;
}
}
void fused_lamb_cuda(at::Tensor& p,
at::Tensor& p_copy,
at::Tensor& m,
at::Tensor& v,
at::Tensor& g,
float lr,
float beta1,
float beta2,
float max_coeff,
float min_coeff,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay,
at::Tensor& w_l2_i,
at::Tensor& u_l2_i,
at::Tensor& lamb_coeff)
{
// using namespace at;
// Get tensor size
int tsize = p.numel();
// Determine #threads and #blocks
const int threadsPerBlock = 512;
int num_blocks = (tsize + threadsPerBlock - 1) / threadsPerBlock;
if (num_blocks > 512) num_blocks = 512;
int smemsize = 0;
if (p.type().scalarType() == at::ScalarType::Double)
smemsize = 2 * threadsPerBlock * sizeof(double);
else
smemsize = 2 * threadsPerBlock * sizeof(float);
const dim3 blocks(num_blocks);
const dim3 threads(threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p),
"parameter tensor is too large to be indexed with int32");
// Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - ::pow(beta1, step);
const float bias_correction2 = 1 - ::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
} else {
step_size = lr;
}
hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
if (g.type().scalarType() == at::ScalarType::Half) {
// all other values should be fp32 for half gradients
AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float,
"expected parameter to be of float type");
// dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
g.scalar_type(), "lamb_cuda_kernel", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
hipLaunchKernelGGL(( lamb_cuda_kernel_part1<accscalar_t, scalar_t, threadsPerBlock>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<accscalar_t>(),
p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
m.data<accscalar_t>(),
v.data<accscalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<accscalar_t>(),
u_l2_i.data<accscalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part2<accscalar_t, scalar_t, threadsPerBlock>)
, dim3(1), dim3(threadsPerBlock), smemsize, stream,
num_blocks, w_l2_i.data<accscalar_t>(), u_l2_i.data<accscalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part3<accscalar_t, scalar_t>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<accscalar_t>(),
p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
m.data<accscalar_t>(),
v.data<accscalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
max_coeff,
min_coeff,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<accscalar_t>(),
u_l2_i.data<accscalar_t>(),
lamb_coeff.data<accscalar_t>());
}));
} else {
using namespace at;
AT_DISPATCH_FLOATING_TYPES(
g.scalar_type(), "lamb_cuda_kernel", ([&] {
hipLaunchKernelGGL(( lamb_cuda_kernel_part1<scalar_t, scalar_t, threadsPerBlock>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<scalar_t>(),
NULL, // don't output p_copy for fp32, it's wasted write
m.data<scalar_t>(),
v.data<scalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<scalar_t>(),
u_l2_i.data<scalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part2<scalar_t, scalar_t, threadsPerBlock>)
, dim3(1), dim3(threadsPerBlock), smemsize, stream,
num_blocks, w_l2_i.data<scalar_t>(), u_l2_i.data<scalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part3<scalar_t, scalar_t>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<scalar_t>(),
NULL, // don't output p_copy for fp32, it's wasted write
m.data<scalar_t>(),
v.data<scalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
max_coeff,
min_coeff,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<scalar_t>(),
u_l2_i.data<scalar_t>(),
lamb_coeff.data<scalar_t>());
}));
}
THCudaCheck(hipGetLastError());
}
// template __device__ void reduce_two_vectors_in_register<float,512>(float a, float b, float* g_a,
// float* g_b, cg::grid_group &cgg);
// DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
// https://github.com/ptillet/torch-blocksparse/blob/master/csrc/utils.cpp
#include <torch/extension.h>
#include <string>
#include <tuple>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif
typedef std::vector<std::tuple<int, torch::Tensor>> ret_t;
void segment_blocks(torch::Tensor layout,
torch::Tensor idx,
torch::Tensor scratch,
int max_width,
ret_t& ret)
{
size_t H = layout.size(0);
size_t M = layout.size(1);
size_t N = layout.size(2);
torch::Tensor tmp = torch::zeros_like(layout);
auto _tmp = tmp.accessor<int, 3>();
auto _layout = layout.accessor<int, 3>();
auto _idx = idx.accessor<int, 3>();
auto _scratch = scratch.accessor<int, 3>();
std::vector<int> current(H, 0);
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (size_t h = 0; h < H; h++) {
// surrounding indices
std::vector<int> ii_left(max_width, -1);
std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));
for (size_t m = 0; m < M; m++) {
for (size_t n = 0; n < N; n++) {
int v = _layout[h][m][n];
if (v == 0) continue;
int n_left = ii_left[max_width - 1];
int m_top = ii_top[max_width - 1][n];
int top = (m_top >= 0) ? _tmp[h][m_top][n] : 0;
int left = (n_left >= 0) ? _tmp[h][m][n_left] : 0;
int topleft = (m_top >= 0 && n_left >= 0) ? _tmp[h][m_top][n_left] : 0;
int width = std::min(left, std::min(top, topleft)) + 1;
// reset width if blocks cannot be
// packed together (i.e., there's a 1 "in the middle")
for (int nn = n_left + 1; nn < n; nn++)
if (ii_top[max_width - 1][nn] > ii_top[max_width - 1][n]) width = 1;
_tmp[h][m][n] = width;
// update n_left ring buffer
for (int k = 0; k < max_width - 1; k++) ii_left[k] = ii_left[k + 1];
ii_left[max_width - 1] = n;
// update ii_top ring buffer
for (int k = 0; k < max_width - 1; k++) ii_top[k][n] = ii_top[k + 1][n];
ii_top[max_width - 1][n] = m;
// block is too small -- skip
if (width != max_width) continue;
// retained blocks are set to zeros
for (size_t km = 0; km < max_width; km++)
for (size_t kn = 0; kn < max_width; kn++) {
int mm = ii_top[km][n];
int nn = ii_left[kn];
if (mm < 0 || nn < 0) continue;
_layout[h][mm][nn] = 0;
_tmp[h][mm][nn] = 0;
_scratch[h][current[h]][0] = (int)h;
_scratch[h][current[h]][1] = (int)mm;
_scratch[h][current[h]][2] = (int)nn;
_scratch[h][current[h]][3] = _idx[h][mm][nn];
current[h]++;
}
}
}
}
std::vector<torch::Tensor> to_cat;
for (size_t h = 0; h < H; h++)
if (current[h] > 0) to_cat.push_back(scratch[h].slice(0, 0, current[h]));
if (!to_cat.empty()) ret.push_back({max_width, torch::cat(to_cat)});
}
ret_t sdd_segment(torch::Tensor layout, int start_width)
{
ret_t ret;
// block index
torch::Tensor idx = torch::zeros_like(layout);
int current = 0;
size_t H = layout.size(0);
size_t M = layout.size(1);
size_t N = layout.size(2);
auto _layout = layout.accessor<int, 3>();
auto _idx = idx.accessor<int, 3>();
for (size_t h = 0; h < H; h++)
for (size_t m = 0; m < M; m++)
for (size_t n = 0; n < N; n++) {
if (_layout[h][m][n] == 0) continue;
_idx[h][m][n] = current++;
}
// scratch memory
//torch::Tensor scratch = torch::empty({H, layout.sum().item<int>(), 4}, layout.dtype());
//aiss debug
torch::Tensor scratch = torch::empty({(long)H, layout.sum().item<int>(), 4}, layout.dtype());
for (int max_width = start_width; max_width > 0; max_width /= 2)
segment_blocks(layout, idx, scratch, max_width, ret);
return ret;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("sdd_segment", &sdd_segment, "SDD segmentation handler");
}
#include "hip/cublas_wrappers.h"
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
//cublasGemmAlgo_t algo)
rocblas_gemm_algo algo)
{
rocblas_status status = rocblas_gemm_ex(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
rocblas_datatype_f32_r,
(transa == rocblas_operation_none) ? m : k,
(const void*)B,
rocblas_datatype_f32_r,
(transb == rocblas_operation_none) ? k : n,
(const void*)beta,
C,
rocblas_datatype_f32_r,
m,
C,
rocblas_datatype_f32_r,
m,
rocblas_datatype_f32_r,
algo,
0,
0);
if (status != rocblas_status_success) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
//cublasGemmAlgo_t algo)
rocblas_gemm_algo algo)
{
rocblas_status status = rocblas_gemm_ex(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
rocblas_datatype_f16_r,
(transa == rocblas_operation_none) ? m : k,
(const void*)B,
rocblas_datatype_f16_r,
(transb == rocblas_operation_none) ? k : n,
(const void*)beta,
(void*)C,
rocblas_datatype_f16_r,
m,
(void*)C,
rocblas_datatype_f16_r,
m,
rocblas_datatype_f16_r,
algo,
0,
0);
if (status != rocblas_status_success) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_strided_batched_gemm(rocblas_handle handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
rocblas_operation op_A,
rocblas_operation op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
rocblas_gemm_algo algo)
{
rocblas_status status = rocblas_gemm_strided_batched_ex(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
rocblas_datatype_f32_r,
(op_A == rocblas_operation_none) ? m : k,
stride_A,
B,
rocblas_datatype_f32_r,
(op_B == rocblas_operation_none) ? k : n,
stride_B,
beta,
C,
rocblas_datatype_f32_r,
m,
stride_C,
C,
rocblas_datatype_f32_r,
m,
stride_C,
batch,
rocblas_datatype_f32_r,
algo,
0,
0);
if (status != rocblas_status_success) {
fprintf(stderr,
"!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
batch,
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_strided_batched_gemm(rocblas_handle handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
rocblas_operation op_A,
rocblas_operation op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
rocblas_gemm_algo algo)
{
rocblas_status status = rocblas_gemm_strided_batched_ex(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
rocblas_datatype_f16_r,
(op_A == rocblas_operation_none) ? m : k,
stride_A,
B,
rocblas_datatype_f16_r,
(op_B == rocblas_operation_none) ? k : n,
stride_B,
beta,
C,
rocblas_datatype_f16_r,
m,
stride_C,
C,
rocblas_datatype_f16_r,
m,
stride_C,
batch,
rocblas_datatype_f16_r,
algo,
0,
0);
if (status != rocblas_status_success) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
#include "cublas_wrappers.h"
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
cublasGemmAlgo_t algo)
{
rocblas_status status = rocblas_gemmex(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
hipR32F,
(transa == rocblas_operation_none) ? m : k,
(const void*)B,
hipR32F,
(transb == rocblas_operation_none) ? k : n,
(const void*)beta,
C,
hipR32F,
m,
hipR32F,
algo);
if (status != rocblas_status_success) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
cublasGemmAlgo_t algo)
{
rocblas_status status = rocblas_gemmex(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
hipR16F,
(transa == rocblas_operation_none) ? m : k,
(const void*)B,
hipR16F,
(transb == rocblas_operation_none) ? k : n,
(const void*)beta,
(void*)C,
hipR16F,
m,
hipR32F,
algo);
if (status != rocblas_status_success) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_strided_batched_gemm(rocblas_handle handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
rocblas_operation op_A,
rocblas_operation op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
cublasGemmAlgo_t algo)
{
rocblas_status status = cublasGemmStridedBatchedEx(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
hipR32F,
(op_A == rocblas_operation_none) ? m : k,
stride_A,
B,
hipR32F,
(op_B == rocblas_operation_none) ? k : n,
stride_B,
beta,
C,
hipR32F,
m,
stride_C,
batch,
hipR32F,
algo);
if (status != rocblas_status_success) {
fprintf(stderr,
"!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
batch,
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_strided_batched_gemm(rocblas_handle handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
rocblas_operation op_A,
rocblas_operation op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
cublasGemmAlgo_t algo)
{
rocblas_status status = cublasGemmStridedBatchedEx(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
hipR16F,
(op_A == rocblas_operation_none) ? m : k,
stride_A,
B,
hipR16F,
(op_B == rocblas_operation_none) ? k : n,
stride_B,
beta,
C,
hipR16F,
m,
stride_C,
batch,
hipR32F,
algo);
if (status != rocblas_status_success) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
#include "hip/hip_runtime.h"
#include "hip/custom_hip_layers.h"
const int unroll_factor = 4;
__global__ void dropout_kernel(const int N,
const float ratio,
float* out,
const float* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float4 rand = hiprand_uniform4(&state);
uint8_t m[unroll_factor];
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
int i = j * unroll_factor;
mask[i] = (uint8_t)m[0];
mask[i + 1] = (uint8_t)m[1];
mask[i + 2] = (uint8_t)m[2];
mask[i + 3] = (uint8_t)m[3];
out[i] = Xdata[i] * scale * m[0];
out[i + 1] = Xdata[i + 1] * scale * m[1];
out[i + 2] = Xdata[i + 2] * scale * m[2];
out[i + 3] = Xdata[i + 3] * scale * m[3];
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
out[i] = Xdata[i] * scale * m;
mask[i] = m;
}
}
}
__global__ void dropout_kernel(const int N,
const float ratio,
__half* out,
const __half* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
#ifdef __STOCHASTIC_MODE__
const __half2 h_scale = __float2half2_rn(scale);
const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
float2* out_cast = reinterpret_cast<float2*>(out);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
uint32_t m_32;
uint8_t* m = reinterpret_cast<uint8_t*>(&m_32);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
__half2 mask_h[2];
float2 mask_f[2];
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_f = x_cast[j];
__half2* x_h = reinterpret_cast<__half2*>(&x_f);
float4 rand = hiprand_uniform4(&state);
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
float* mask_f_data = &mask_f[0].x;
#pragma unroll
for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);
mask_h[0] = __float22half2_rn(mask_f[0]);
mask_h[1] = __float22half2_rn(mask_f[1]);
result_h[0] = x_h[0] * h_scale * mask_h[0];
result_h[1] = x_h[1] * h_scale * mask_h[1];
out_cast[j] = result_f;
mask_cast[j] = m_32;
}
#else
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
int i = j * unroll_factor;
const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
float2 vals_half_f[2];
vals_half_f[0] = __half22float2(vals_half[0]);
vals_half_f[1] = __half22float2(vals_half[1]);
uint8_t m[unroll_factor];
float4 rand = hiprand_uniform4(&state);
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);
mask[i] = m[0];
mask[i + 1] = m[1];
mask[i + 2] = m[2];
mask[i + 3] = m[3];
}
#endif
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
out[i] = __float2half((float)Xdata[i] * scale * m);
mask[i] = m;
}
}
}
__global__ void dropout_kernel_bwd(const int N,
const float ratio,
const float* Xdata,
float* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
int i = j * unroll_factor;
out[i] = mask[i] ? Xdata[i] * scale : 0.0;
out[i + 1] = mask[i + 1] ? Xdata[i + 1] * scale : 0.0;
out[i + 2] = mask[i + 2] ? Xdata[i + 2] * scale : 0.0;
out[i + 3] = mask[i + 3] ? Xdata[i + 3] * scale : 0.0;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) { out[i] = mask[i] ? Xdata[i] * scale : 0.0; }
}
}
__global__ void dropout_kernel_bwd(const int N,
const float ratio,
const __half* Xdata,
__half* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
#ifdef __STOCHASTIC_MODE__
const __half2 h_scale = __float2half2_rn(scale);
const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
float2* out_cast = reinterpret_cast<float2*>(out);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_f = x_cast[j];
__half2* x_h = reinterpret_cast<__half2*>(&x_f);
uint32_t m_32 = mask_cast[j];
uint8_t* m = (uint8_t*)&m_32;
__half2 mask_h[2];
float2 mask_f[2];
float* mask_f_data = &mask_f[0].x;
#pragma unroll
for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);
#pragma unroll
for (int i = 0; i < 2; i++) mask_h[i] = __float22half2_rn(mask_f[i]);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
result_h[0] = x_h[0] * h_scale * mask_h[0];
result_h[1] = x_h[1] * h_scale * mask_h[1];
out_cast[j] = result_f;
}
#else
const __half h_scale = __float2half(scale);
const __half h_zero = __float2half(0.0);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
int i = j * unroll_factor;
const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
uint8_t* m = mask + i;
float2 vals_half_f[2];
vals_half_f[0] = __half22float2(vals_half[0]);
vals_half_f[1] = __half22float2(vals_half[1]);
out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);
}
#endif
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) {
out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
}
}
}
template <typename T>
void launch_dropout(T* out,
const T* vals,
uint8_t* mask,
int total_count,
int dim,
float ratio,
hipStream_t stream,
bool bwd)
{
assert(unroll_factor == 4);
dim3 grid_dim = DS_GET_BLOCKS(total_count / unroll_factor);
dim3 block_dim = DS_CUDA_NUM_THREADS;
if (dim > 512) {
block_dim.x >>= 1;
grid_dim.x <<= 1;
}
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
if (bwd)
hipLaunchKernelGGL(( dropout_kernel_bwd), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, ratio, vals, out, mask, seed);
else
hipLaunchKernelGGL(( dropout_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, ratio, out, vals, mask, seed);
}
template void launch_dropout(float* out,
const float* vals,
uint8_t* mask,
int total_count,
int dim,
float ratio,
hipStream_t stream,
bool);
template void launch_dropout(__half* out,
const __half* vals,
uint8_t* mask,
int total_count,
int dim,
float ratio,
hipStream_t stream,
bool);
__global__ void dropout_grad_kernel(const int N, const float scale, float* Xdata, uint8_t* mask)
{
CUDA_1D_KERNEL_LOOP(i, N) { Xdata[i] *= scale * mask[i]; }
}
__global__ void dropout_grad_kernel(const int N, const float scale, __half* Xdata, uint8_t* mask)
{
const __half2 h_scale = __float2half2_rn(scale);
float2* x_cast = reinterpret_cast<float2*>(Xdata);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_data = x_cast[j];
uint32_t m_32 = mask_cast[j];
uint8_t* m = (uint8_t*)&m_32;
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
#ifdef __STOCHASTIC_MODE__
__half2* x_data_h = reinterpret_cast<__half2*>(&x_data);
__half2 mask_h[2];
float2 mask_f[2];
float* mask_f_data = &mask_f[0].x;
#pragma unroll
for (int i = 0; i < unroll_factor; i++) *(mask_f_data++) = (float)(m[i]);
mask_h[0] = __float22half2_rn(mask_f[0]);
mask_h[1] = __float22half2_rn(mask_f[1]);
result_h[0] = x_data_h[0] * h_scale * mask_h[0];
result_h[1] = x_data_h[1] * h_scale * mask_h[1];
#else
__half* x_data_h = reinterpret_cast<__half*>(&x_data);
float2 result[2];
result[0].x = (float)x_data_h[0] * scale * m[0];
result[0].y = (float)x_data_h[1] * scale * m[1];
result[1].x = (float)x_data_h[2] * scale * m[2];
result[1].y = (float)x_data_h[3] * scale * m[3];
result_h[0] = __float22half2_rn(result[0]);
result_h[1] = __float22half2_rn(result[1]);
#endif
x_cast[j] = result_f;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) {
Xdata[i] = __float2half((float)Xdata[i] * scale * mask[i]);
}
}
}
template <typename T>
void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, hipStream_t stream)
{
assert(unroll_factor == 4);
const float scale = 1. / (1. - ratio);
hipLaunchKernelGGL(( dropout_grad_kernel), dim3(DS_GET_BLOCKS(total_count / unroll_factor)),
dim3(DS_CUDA_NUM_THREADS),
0,
stream, total_count, scale, vals, mask);
}
template void launch_dropout_grad(float* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
template void launch_dropout_grad(__half* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
__global__ void dropout_grad_kernel(const int N,
const float scale,
const float* Xdata,
float* out,
uint8_t* mask)
{
CUDA_1D_KERNEL_LOOP(i, N) { out[i] = Xdata[i] * scale * mask[i]; }
}
__global__ void dropout_grad_kernel(const int N,
const float scale,
const __half* Xdata,
__half* out,
uint8_t* mask)
{
const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
float2* out_cast = reinterpret_cast<float2*>(out);
const uint32_t* mask_cast = reinterpret_cast<const uint32_t*>(mask);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_data = x_cast[j];
uint32_t m_32 = mask_cast[j];
uint8_t* m = (uint8_t*)&m_32;
__half* x_data_h = reinterpret_cast<__half*>(&x_data);
float2 result[2];
result[0].x = (float)x_data_h[0] * scale * m[0];
result[0].y = (float)x_data_h[1] * scale * m[1];
result[1].x = (float)x_data_h[2] * scale * m[2];
result[1].y = (float)x_data_h[3] * scale * m[3];
result_h[0] = __float22half2_rn(result[0]);
result_h[1] = __float22half2_rn(result[1]);
out_cast[j] = result_f;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) {
out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
}
}
}
template <typename T>
void launch_dropout_grad(T* vals_out,
const T* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream)
{
assert(unroll_factor == 4);
const float scale = 1. / (1. - ratio);
hipLaunchKernelGGL(( dropout_grad_kernel), dim3(DS_GET_BLOCKS(total_count / unroll_factor)),
dim3(DS_CUDA_NUM_THREADS),
0,
stream, total_count, scale, vals, vals_out, mask);
}
template void launch_dropout_grad(float*,
const float* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
template void launch_dropout_grad(__half*,
const __half* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const float* bias,
float* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
float4* Xdata_cast = reinterpret_cast<float4*>(Xdata);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = hiprand_uniform4(&state);
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
float4 x_data = Xdata_cast[j];
float4 b_data = bias_cast[j % (dim / unroll_factor)];
x_data.x += b_data.x;
x_data.y += b_data.y;
x_data.z += b_data.z;
x_data.w += b_data.w;
x_data.x = x_data.x * scale * m[0];
x_data.y = x_data.y * scale * m[1];
x_data.z = x_data.z * scale * m[2];
x_data.w = x_data.w * scale * m[3];
mask_32[j] = m_32;
Xdata_cast[j] = x_data;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = Xdata[i] + bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
Xdata[i] = x_data * scale * m;
mask[i] = m;
}
}
}
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const __half* bias,
__half* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
float2* Xdata_cast = reinterpret_cast<float2*>(Xdata);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = hiprand_uniform4(&state);
float2 data_f;
__half2* data_h = reinterpret_cast<__half2*>(&data_f);
float2 bias_f;
__half2* bias_h = reinterpret_cast<__half2*>(&bias_f);
data_f = Xdata_cast[j];
bias_f = bias_cast[j % (dim / unroll_factor)];
float2 data_h_0 = __half22float2(data_h[0]);
float2 data_h_1 = __half22float2(data_h[1]);
float2 bias_h_0 = __half22float2(bias_h[0]);
float2 bias_h_1 = __half22float2(bias_h[1]);
data_h_0.x += bias_h_0.x;
data_h_0.y += bias_h_0.y;
data_h_1.x += bias_h_1.x;
data_h_1.y += bias_h_1.y;
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
data_h_1.y = __float2half(data_h_1.y * scale * m[3]);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
result_h[0] = __float22half2_rn(data_h_0);
result_h[1] = __float22half2_rn(data_h_1);
Xdata_cast[j] = result_f;
mask_32[j] = m_32;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = (float)Xdata[i] + (float)bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
Xdata[i] = __float2half(x_data * scale * m);
mask[i] = m;
}
}
}
template <typename T>
void launch_dropout(T* out,
const T* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream)
{
assert(unroll_factor == 4);
int total_count = batch * dim / unroll_factor;
dim3 grid_dim = DS_GET_BLOCKS(total_count);
dim3 block_dim = DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
hipLaunchKernelGGL(( dropout_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, dim, ratio, bias, out, mask, seed);
}
template void launch_dropout(float*,
const float* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
template void launch_dropout(__half*,
const __half* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const float* input,
const float* residual,
const float* bias,
float* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
float4* out_cast = reinterpret_cast<float4*>(out);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
const float4* residual_cast = reinterpret_cast<const float4*>(residual);
const float4* input_cast = reinterpret_cast<const float4*>(input);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = hiprand_uniform4(&state);
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
float4 out_data;
float4 b_data = bias_cast[j % (dim / unroll_factor)];
float4 res_data = residual_cast[j];
float4 inp_data = input_cast[j];
out_data.x = (b_data.x + inp_data.x);
out_data.y = (b_data.y + inp_data.y);
out_data.z = (b_data.z + inp_data.z);
out_data.w = (b_data.w + inp_data.w);
out_data.x = out_data.x * scale * m[0];
out_data.y = out_data.y * scale * m[1];
out_data.z = out_data.z * scale * m[2];
out_data.w = out_data.w * scale * m[3];
out_data.x += res_data.x;
out_data.y += res_data.y;
out_data.z += res_data.z;
out_data.w += res_data.w;
mask_32[j] = m_32;
out_cast[j] = out_data;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = input[i] + bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
x_data = x_data * scale * m;
x_data += residual[i];
out[i] = x_data;
mask[i] = m;
}
}
}
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const __half* input,
const __half* residual,
const __half* bias,
__half* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
float2* out_cast = reinterpret_cast<float2*>(out);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
const float2* residual_cast = reinterpret_cast<const float2*>(residual);
const float2* input_cast = reinterpret_cast<const float2*>(input);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = hiprand_uniform4(&state);
float2 data_f;
__half2* data_h = reinterpret_cast<__half2*>(&data_f);
float2 bias_f;
__half2* bias_h = reinterpret_cast<__half2*>(&bias_f);
float2 residual_f;
__half2* residual_h = reinterpret_cast<__half2*>(&residual_f);
float2 input_f;
__half2* input_h = reinterpret_cast<__half2*>(&input_f);
bias_f = bias_cast[j % (dim / unroll_factor)];
residual_f = residual_cast[j];
input_f = input_cast[j];
float2 data_h_0 = __half22float2(data_h[0]);
float2 data_h_1 = __half22float2(data_h[1]);
float2 bias_h_0 = __half22float2(bias_h[0]);
float2 bias_h_1 = __half22float2(bias_h[1]);
float2 residual_h_0 = __half22float2(residual_h[0]);
float2 residual_h_1 = __half22float2(residual_h[1]);
float2 input_h_0 = __half22float2(input_h[0]);
float2 input_h_1 = __half22float2(input_h[1]);
data_h_0.x = (bias_h_0.x + input_h_0.x);
data_h_0.y = (bias_h_0.y + input_h_0.y);
data_h_1.x = (bias_h_1.x + input_h_1.x);
data_h_1.y = (bias_h_1.y + input_h_1.y);
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
data_h_1.y = __float2half(data_h_1.y * scale * m[3]);
data_h_0.x += residual_h_0.x;
data_h_0.y += residual_h_0.y;
data_h_1.x += residual_h_1.x;
data_h_1.y += residual_h_1.y;
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
result_h[0] = __float22half2_rn(data_h_0);
result_h[1] = __float22half2_rn(data_h_1);
out_cast[j] = result_f;
mask_32[j] = m_32;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = (float)input[i] + (float)bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
x_data = x_data * scale * m;
x_data += (float)residual[i];
out[i] = __float2half(x_data);
mask[i] = m;
}
}
}
template <typename T>
void launch_dropout(T* out,
const T* input,
const T* residual,
const T* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream)
{
assert(unroll_factor == 4);
int total_count = batch * dim / unroll_factor;
dim3 grid_dim = DS_GET_BLOCKS(total_count);
dim3 block_dim = DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
hipLaunchKernelGGL(( dropout_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, dim, ratio, input, residual, bias, out, mask, seed);
}
template void launch_dropout(float*,
const float*,
const float* residual,
const float* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
template void launch_dropout(__half*,
const __half*,
const __half* residual,
const __half* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
#include <torch/extension.h>
#include <rocblas.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "hip/Timer.h"
#include "hip/context.h"
#include "hip/cublas_wrappers.h"
#include "hip/custom_hip_layers.h"
#include "hip/ds_transformer_hip.h"
static std::unordered_map<int, std::shared_ptr<void>> s_transformer_layers;
const int init_seq_length = 128;
// C++ interface
template <typename T>
size_t get_workspace_size(int maxBatchSize,
int seq_len,
int hidden_size,
int intermediate_size,
int heads,
bool training,
bool gelu_checkpoint)
{
size_t workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size);
if (training) {
workSpacesize += ((std::max)((size_t(maxBatchSize) * seq_len * intermediate_size),
2 * (size_t(maxBatchSize) * heads * seq_len * seq_len)));
if (gelu_checkpoint)
workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * intermediate_size);
}
return workSpacesize; // * sizeof(T);
}
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
template <typename T>
BertTransformerLayer<T>::BertTransformerLayer(int layer_id,
int batch_size,
int hidden_size,
int num_heads,
int intermediate_size,
int seq_length,
float attn_prob_dropout_ratio,
float hidden_output_dropout_ratio,
float layer_norm_eps,
bool pre_or_postLayerNorm,
const std::vector<std::array<int, 3>>& gemm_algos,
bool attn_dropout_checkpoint,
bool normalize_invertible,
bool gelu_checkpoint,
bool stochastic_mode)
: _layer_id(layer_id),
_batch_size(batch_size),
_hidden_size(hidden_size),
_heads(num_heads),
_intermediate_size(intermediate_size),
_seq_length(seq_length),
_training(true),
_pre_or_postLayerNorm(pre_or_postLayerNorm),
_attn_dropout_checkpoint(attn_dropout_checkpoint),
_normalize_invertible(normalize_invertible),
_gelu_checkpoint(gelu_checkpoint),
_stochastic_mode(stochastic_mode),
_stream(Context::Instance().GetCurrentStream()),
_cublasHandle(Context::Instance().GetCublasHandle()),
_qkv_linear(typename FeedForward<T>::Config(batch_size * seq_length,
3 * hidden_size,
hidden_size,
gemm_algos[0])),
_attn_out_linear(typename FeedForward<T>::Config(batch_size * seq_length,
hidden_size,
hidden_size,
gemm_algos[0])),
_attn_layer_norm(typename Normalize_Layer<T>::Config(batch_size,
seq_length,
hidden_size,
layer_norm_eps,
true,
!normalize_invertible)),
_layer_norm(typename Normalize_Layer<T>::Config(batch_size,
seq_length,
hidden_size,
layer_norm_eps,
true,
!normalize_invertible)),
_ff1(typename FeedForward<T>::Config(batch_size * seq_length,
_intermediate_size,
hidden_size,
gemm_algos[1])),
_ff2(typename FeedForward<T>::Config(batch_size * seq_length,
hidden_size,
_intermediate_size,
gemm_algos[2])),
_softmax(typename Softmax<T>::Config(batch_size, num_heads, seq_length)),
_gelu(typename Gelu<T>::Config(_intermediate_size)),
_attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio, _seq_length)),
_attn_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
_layer_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
_attn_scores(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
_seq_length,
_seq_length,
_hidden_size / _heads,
//(T(1.0) / T(sqrt(_hidden_size / _heads))),
(T(1.0 / (sqrt(_hidden_size / _heads)))),
T(0.0),
rocblas_operation_transpose,
rocblas_operation_none,
gemm_algos[3])),
_attn_context(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
_hidden_size / _heads,
_seq_length,
_seq_length,
T(1.0),
T(0.0),
rocblas_operation_none,
rocblas_operation_none,
gemm_algos[4]))
{
assert(_hidden_size % _heads == 0);
Initialize();
}
template <typename T>
BertTransformerLayer<T>::~BertTransformerLayer()
{
}
template <typename T>
void BertTransformerLayer<T>::Initialize()
{
//aiss debug:rocm has no CUBLAS_TENSOR_OP_MATH
//if (std::is_same<T, __half>::value) rocblas_set_math_mode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
}
template <typename T>
void BertTransformerLayer<T>::Forward(int bsz,
const T* input_ptr,
const T* input_mask_ptr,
const T* attn_qkvw_ptr,
const T* attn_qkvb_ptr,
const T* attn_ow_ptr,
const T* attn_ob_ptr,
const T* attn_nw_ptr,
const T* attn_nb_ptr,
const T* inter_w_ptr,
const T* inter_b_ptr,
const T* output_w_ptr,
const T* output_b_ptr,
const T* norm_w_ptr,
const T* norm_b_ptr,
T* out_ptr,
T* inp_norm_ptr,
T* q_tf_ptr,
T* k_tf_ptr,
T* v_tf_ptr,
T* soft_out_ptr,
T* ctx_bufB_ptr,
T* attn_o_inp_ptr,
T* add_res_ptr,
T* ff1_inp_ptr,
T* gelu_inp_ptr,
T* ff2_inp_ptr)
{
rocblas_set_stream(_cublasHandle, _stream);
if (!_stochastic_mode) hipStreamSynchronize(_stream);
T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
size_t small_buf_size = bsz * _seq_length * _hidden_size;
T* buf_0 = workspace;
T* buf_1 = buf_0 + small_buf_size;
T* buf_2 = buf_1;
if (_normalize_invertible) {
add_res_ptr = buf_1 + 3 * small_buf_size;
buf_2 = add_res_ptr;
}
if (_gelu_checkpoint) buf_2 += small_buf_size;
if (_attn_dropout_checkpoint)
ctx_bufB_ptr =
(_gelu_checkpoint ? (buf_2 + (_intermediate_size / _hidden_size) * small_buf_size)
: (buf_1 + 4 * small_buf_size));
int bsz_seq = bsz * _seq_length;
if (_pre_or_postLayerNorm) {
if (_layer_norm.UseMean())
_layer_norm.ForwardCheckpoint(
bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
else
_layer_norm.Forward(
bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
}
if (_pre_or_postLayerNorm)
_qkv_linear.Forward(bsz_seq, inp_norm_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);
else
_qkv_linear.Forward(bsz_seq, input_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);
launch_bias_add_transform_0213<T>(
q_tf_ptr, buf_0, attn_qkvb_ptr, bsz, _seq_length, _hidden_size, _heads, _stream, 3);
int bsz_heads = bsz * _heads;
// attention scores
_attn_scores.Forward(bsz_heads, soft_out_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle);
// Softmax + Mask
_softmax.Forward(bsz, soft_out_ptr, input_mask_ptr, _stream);
// attn prob dropout.
_attn_prob_dropout.Forward(bsz_heads * _seq_length, ctx_bufB_ptr, soft_out_ptr, _stream);
// attention context
_attn_context.Forward(bsz_heads, buf_1, v_tf_ptr, ctx_bufB_ptr, _cublasHandle);
launch_transform4d_0213<T>(
attn_o_inp_ptr, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 1);
if (_pre_or_postLayerNorm)
_attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, buf_1, _cublasHandle);
else
_attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, ff1_inp_ptr, _cublasHandle);
// attn output dropout.
if (_pre_or_postLayerNorm)
_attn_output_dropout.ForwardWithBias(
bsz_seq, add_res_ptr, buf_1, input_ptr, attn_ob_ptr, _stream);
else
_attn_output_dropout.ForwardWithBias(
bsz_seq, add_res_ptr, ff1_inp_ptr, input_ptr, attn_ob_ptr, _stream);
if (_pre_or_postLayerNorm) {
if (_attn_layer_norm.UseMean())
_attn_layer_norm.ForwardCheckpoint(
bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
else
_attn_layer_norm.Forward(
bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
} else {
if (_attn_layer_norm.UseMean())
_attn_layer_norm.ForwardCheckpoint(
bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
else
_attn_layer_norm.Forward(
bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
}
_ff1.Forward(bsz_seq,
ff1_inp_ptr,
inter_w_ptr,
(_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
_cublasHandle);
_gelu.ForwardWithBiasAdd(bsz_seq,
(_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
inter_b_ptr,
(_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
_stream);
_ff2.Forward(
bsz_seq, (_gelu_checkpoint ? buf_2 : ff2_inp_ptr), output_w_ptr, out_ptr, _cublasHandle);
// layer output dropout.
if (_pre_or_postLayerNorm)
_layer_output_dropout.ForwardWithBias(
bsz_seq, out_ptr, out_ptr, add_res_ptr, output_b_ptr, _stream);
else
_layer_output_dropout.ForwardWithBias(
bsz_seq, inp_norm_ptr, out_ptr, ff1_inp_ptr, output_b_ptr, _stream);
if (!_pre_or_postLayerNorm) {
if (_layer_norm.UseMean())
_layer_norm.ForwardCheckpoint(
bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
else
_layer_norm.Forward(
bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
}
}
template <typename T>
void BertTransformerLayer<T>::Backward(int bsz,
const T* grad_output_ptr,
const T* input_ptr,
const T* output_ptr,
const T* inp_norm_ptr,
const T* q_tf_ptr,
const T* k_tf_ptr,
const T* v_tf_ptr,
const T* soft_out_ptr,
const T* ctx_bufB_ptr,
const T* attn_o_inp_ptr,
const T* add_res_ptr,
const T* ff1_inp_ptr,
const T* gelu_inp_ptr,
const T* ff2_inp_ptr,
const T* input_mask_ptr,
const T* attn_qkvw_ptr,
const T* attn_ow_ptr,
const T* attn_nw_ptr,
const T* attn_nb_ptr,
const T* inter_w_ptr,
const T* inter_b_ptr,
const T* output_w_ptr,
const T* norm_w_ptr,
const T* norm_b_ptr,
T* grad_input_ptr,
T* grad_attn_qkvw_ptr,
T* grad_attn_qkvb_ptr,
T* grad_attn_ow_ptr,
T* grad_attn_ob_ptr,
T* grad_attn_nw_ptr,
T* grad_attn_nb_ptr,
T* grad_inter_w_ptr,
T* grad_inter_b_ptr,
T* grad_output_w_ptr,
T* grad_output_b_ptr,
T* grad_norm_w_ptr,
T* grad_norm_b_ptr)
{
rocblas_set_stream(_cublasHandle, _stream);
if (!_stochastic_mode) hipStreamSynchronize(_stream);
T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
size_t small_buf_size = bsz * _seq_length * _hidden_size;
T* buf_0 = workspace;
T* buf_1 = buf_0 + small_buf_size;
T* buf_2 = buf_1 + small_buf_size;
T* buf_3 = buf_2 + small_buf_size;
T* ff2_buf = (_gelu_checkpoint ? buf_3 + (bsz * _seq_length * _intermediate_size)
: buf_3 + small_buf_size);
T* ctx_bufB_ptr_recomp = ff2_buf + (_seq_length * _seq_length * bsz * _heads);
hipStream_t streams[2] = {_stream, _stream};
int bsz_seq = bsz * _seq_length;
int bsz_heads = bsz * _heads;
if (!_pre_or_postLayerNorm) {
if (_layer_norm.UseMean())
_layer_norm.Backward(bsz_seq,
grad_output_ptr,
norm_w_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr,
streams,
buf_1,
inp_norm_ptr);
else
_layer_norm.Backward(bsz_seq,
grad_output_ptr,
norm_w_ptr,
norm_b_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr,
streams,
buf_1,
output_ptr);
}
if (_pre_or_postLayerNorm)
_layer_output_dropout.Backward(bsz_seq, buf_0, grad_output_ptr, _stream);
else
_layer_output_dropout.Backward(bsz_seq, buf_0, buf_1, _stream);
const T* layer_dropout_buf = _layer_output_dropout.HasDropout()
? buf_0
: (_pre_or_postLayerNorm ? grad_output_ptr : buf_1);
if (_gelu_checkpoint)
_gelu.ForwardWithBiasAdd(bsz_seq, ff2_inp_ptr, inter_b_ptr, buf_2, _stream);
_ff2.Backward(bsz_seq,
layer_dropout_buf,
(_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
output_w_ptr,
grad_output_w_ptr,
grad_output_b_ptr,
_cublasHandle,
_stream,
ff2_buf);
_gelu.Backward(
bsz_seq, ff2_buf, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, _stream);
_ff1.Backward(bsz_seq,
ff2_buf,
ff1_inp_ptr,
inter_w_ptr,
grad_inter_w_ptr,
grad_inter_b_ptr,
_cublasHandle,
_stream,
buf_3);
if (!_pre_or_postLayerNorm)
launch_fused_add2<T>(buf_2, buf_3, buf_1, bsz, _seq_length, _hidden_size, _stream);
if (_pre_or_postLayerNorm) {
if (_attn_layer_norm.UseMean())
_attn_layer_norm.BackwardFusedAdd(bsz_seq,
buf_3,
grad_output_ptr,
attn_nw_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
streams,
buf_0,
add_res_ptr);
else
_attn_layer_norm.BackwardFusedAdd(bsz_seq,
buf_3,
grad_output_ptr,
attn_nw_ptr,
attn_nb_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
streams,
buf_0,
ff1_inp_ptr);
} else {
if (_attn_layer_norm.UseMean())
_attn_layer_norm.Backward(bsz_seq,
buf_2,
attn_nw_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
streams,
buf_0,
add_res_ptr);
else
_attn_layer_norm.Backward(bsz_seq,
buf_2,
attn_nw_ptr,
attn_nb_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
streams,
buf_0,
ff1_inp_ptr);
}
_attn_output_dropout.Backward(bsz_seq, buf_2, buf_0, _stream);
T* attn_output_dropout_buf = _attn_output_dropout.HasDropout() ? buf_2 : buf_0;
_attn_out_linear.Backward(bsz_seq,
attn_output_dropout_buf,
attn_o_inp_ptr,
attn_ow_ptr,
grad_attn_ow_ptr,
grad_attn_ob_ptr,
_cublasHandle,
_stream,
buf_1);
launch_transform_0213<T>(buf_2, buf_1, bsz, _seq_length, _hidden_size, _heads, _stream);
if (_attn_prob_dropout.HasDropout()) {
if (_attn_dropout_checkpoint)
_attn_prob_dropout.Forward(
bsz_heads * _seq_length, ctx_bufB_ptr_recomp, soft_out_ptr, _stream, true);
_attn_context.Backward(bsz_heads,
buf_2,
v_tf_ptr,
(_attn_dropout_checkpoint ? ctx_bufB_ptr_recomp : ctx_bufB_ptr),
_cublasHandle,
buf_3,
ff2_buf);
} else
_attn_context.Backward(
bsz_heads, buf_2, v_tf_ptr, soft_out_ptr, _cublasHandle, buf_3, ff2_buf);
_attn_prob_dropout.Backward(bsz_heads * _seq_length, ff2_buf, _stream);
_softmax.Backward(bsz, ff2_buf, soft_out_ptr, _stream);
_attn_scores.Backward(bsz_heads, ff2_buf, k_tf_ptr, q_tf_ptr, _cublasHandle, buf_2, buf_1);
launch_transform4d_0213(ff2_buf, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 3);
if (_pre_or_postLayerNorm)
_qkv_linear.Backward(bsz_seq,
ff2_buf,
inp_norm_ptr,
attn_qkvw_ptr,
grad_attn_qkvw_ptr,
grad_attn_qkvb_ptr,
_cublasHandle,
_stream,
buf_2);
else
_qkv_linear.Backward(bsz_seq,
ff2_buf,
input_ptr,
attn_qkvw_ptr,
grad_attn_qkvw_ptr,
grad_attn_qkvb_ptr,
_cublasHandle,
_stream,
buf_2);
if (_pre_or_postLayerNorm) {
if (_layer_norm.UseMean())
_layer_norm.BackwardFusedAdd(bsz_seq,
buf_2,
buf_0,
norm_w_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr,
streams,
grad_input_ptr,
input_ptr);
else
_layer_norm.BackwardFusedAdd(bsz_seq,
buf_2,
buf_0,
norm_w_ptr,
norm_b_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr,
streams,
grad_input_ptr,
inp_norm_ptr);
} else
launch_fused_add2<T>(grad_input_ptr, buf_2, buf_0, bsz, _seq_length, _hidden_size, _stream);
}
template <typename T>
void BertTransformerLayer<T>::SetTrainingMode(bool training)
{
// Dropout will be skipped when not in training model.
_attn_prob_dropout.SetTrainingMode(training);
_attn_output_dropout.SetTrainingMode(training);
_layer_output_dropout.SetTrainingMode(training);
}
template <typename T>
void BertTransformerLayer<T>::SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
uint8_t* attn_output_dropout_mask_ptr,
uint8_t* layer_output_dropout_mask_ptr,
T* attn_layer_norm_var,
T* attn_layer_norm_mean,
T* layer_norm_var,
T* layer_norm_mean)
{
_attn_prob_dropout.SetMask(attn_prob_dropout_mask_ptr);
_attn_output_dropout.SetMask(attn_output_dropout_mask_ptr);
_layer_output_dropout.SetMask(layer_output_dropout_mask_ptr);
_attn_layer_norm.SetVar(attn_layer_norm_var);
_attn_layer_norm.SetMean(attn_layer_norm_mean);
_layer_norm.SetVar(layer_norm_var);
_layer_norm.SetMean(layer_norm_mean);
}
template <typename T>
void BertTransformerLayer<T>::SetSeqLength(int seq_len)
{
_seq_length = seq_len;
_softmax.SetSeqLength(_seq_length);
_attn_prob_dropout.SetDimension(_seq_length);
_attn_scores.SetConfig(_seq_length, _seq_length, _hidden_size / _heads);
_attn_context.SetConfig(_hidden_size / _heads, _seq_length, _seq_length);
}
template <typename T>
int create_transformer_layer(int layer_id,
int batch_size,
int hidden_dim,
int num_heads,
int intermediate_size,
float attn_dropout_ratio,
float hidden_dropout_ratio,
float layer_norm_eps,
int seed,
bool pre_or_postLayerNorm,
bool test_gemm,
bool attn_dropout_checkpoint,
bool normalize_invertible,
bool gelu_checkpoint,
bool stochastic_mode)
{
Context::Instance().SetSeed(seed);
Context::Instance().TestGemmFP16(
test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads);
auto layer = std::make_shared<BertTransformerLayer<T>>(layer_id,
batch_size,
hidden_dim,
num_heads,
intermediate_size,
init_seq_length,
attn_dropout_ratio,
hidden_dropout_ratio,
layer_norm_eps,
pre_or_postLayerNorm,
Context::Instance().GetGemmAlgos(),
attn_dropout_checkpoint,
normalize_invertible,
gelu_checkpoint,
stochastic_mode);
s_transformer_layers[layer_id] = layer;
std::string dtype = (std::is_same<T, __half>::value) ? "half" : "float";
std::cout << "layer #" << layer_id << " is created with date type [" << dtype << "]."
<< std::endl;
return 0;
}
template <typename T>
std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
const torch::Tensor& input,
const torch::Tensor& input_mask,
const torch::Tensor& attn_qkvw,
const torch::Tensor& attn_qkvb,
const torch::Tensor& attn_ow,
const torch::Tensor& attn_ob,
const torch::Tensor& attn_nw,
const torch::Tensor& attn_nb,
const torch::Tensor& inter_w,
const torch::Tensor& inter_b,
const torch::Tensor& output_w,
const torch::Tensor& output_b,
const torch::Tensor& norm_w,
const torch::Tensor& norm_b,
bool training_mode,
bool prelayernorm,
bool attn_dropout_checkpoint,
bool normalize_invertible,
bool gelu_checkpoint)
{
CHECK_INPUT(input);
CHECK_INPUT(input_mask);
CHECK_INPUT(attn_qkvw);
CHECK_INPUT(attn_qkvb);
CHECK_INPUT(attn_ow);
CHECK_INPUT(attn_ob);
CHECK_INPUT(attn_nw);
CHECK_INPUT(attn_nb);
CHECK_INPUT(inter_w);
CHECK_INPUT(inter_b);
CHECK_INPUT(output_w);
CHECK_INPUT(output_b);
CHECK_INPUT(norm_w);
CHECK_INPUT(norm_b);
int bsz = input.size(0);
const T* input_ptr = (const T*)input.data_ptr();
const T* input_mask_ptr = (const T*)input_mask.data_ptr();
const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
const T* attn_qkvb_ptr = (const T*)attn_qkvb.data_ptr();
const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
const T* attn_ob_ptr = (const T*)attn_ob.data_ptr();
const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
const T* inter_w_ptr = (const T*)inter_w.data_ptr();
const T* inter_b_ptr = (const T*)inter_b.data_ptr();
const T* output_w_ptr = (const T*)output_w.data_ptr();
const T* output_b_ptr = (const T*)output_b.data_ptr();
const T* norm_w_ptr = (const T*)norm_w.data_ptr();
const T* norm_b_ptr = (const T*)norm_b.data_ptr();
auto output = torch::empty_like(input);
T* out_ptr = (T*)output.data_ptr();
auto options = torch::TensorOptions()
.dtype(input.options().dtype())
.layout(torch::kStrided)
.device(torch::kCUDA)
.requires_grad(true);
auto uint8_options = torch::TensorOptions()
.dtype(torch::kInt8)
.layout(torch::kStrided)
.device(torch::kCUDA)
.requires_grad(false);
std::shared_ptr<BertTransformerLayer<T>> layer =
std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);
int seq_len = layer->GetSeqLength();
if (input.size(1) != seq_len) {
seq_len = input.size(1);
layer->SetSeqLength(seq_len);
}
auto workspace = torch::empty({get_workspace_size<T>(bsz,
seq_len,
layer->GetHiddenSize(),
layer->GetIntermediateSize(),
layer->GetNumHeads(),
layer->IsTrainingMode(),
layer->GeluCheckpoint())},
options);
Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
auto attn_o_inp = torch::empty_like(input);
auto qkv_tf = torch::empty({(bsz * seq_len), output_w.size(0) * 3}, options);
auto attn_prob_dropout_mask =
torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, uint8_options);
auto attn_output_dropout_mask =
torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);
auto layer_output_dropout_mask =
torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);
auto attn_layer_norm_var = torch::empty({(bsz * seq_len)}, options);
auto attn_layer_norm_mean = torch::empty({(bsz * seq_len)}, options);
auto layer_norm_var = torch::empty({(bsz * seq_len)}, options);
auto layer_norm_mean = torch::empty({(bsz * seq_len)}, options);
T* inp_norm_ptr = (T*)inp_norm.data_ptr();
T* add_res_ptr = (T*)add_res.data_ptr();
T* q_tf_ptr = (T*)qkv_tf.data_ptr();
T* k_tf_ptr = q_tf_ptr + (bsz * seq_len * output_w.size(0)); //(T*)k_tf.data_ptr();
T* v_tf_ptr = k_tf_ptr + (bsz * seq_len * output_w.size(0)); //(T*)v_tf.data_ptr();
T* attn_o_inp_ptr = (T*)attn_o_inp.data_ptr();
torch::Tensor ff2_inp = torch::empty({(bsz * seq_len), output_w.size(1)}, options);
torch::Tensor gelu_inp =
(gelu_checkpoint ? ff2_inp : torch::empty({(bsz * seq_len), output_w.size(1)}, options));
auto ff1_inp = torch::empty_like(input);
T* ff2_inp_ptr = (T*)ff2_inp.data_ptr();
T* gelu_inp_ptr = (T*)gelu_inp.data_ptr();
T* ff1_inp_ptr = (T*)ff1_inp.data_ptr();
torch::Tensor soft_out =
torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options);
torch::Tensor ctx_bufB =
(attn_dropout_checkpoint
? soft_out
: torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options));
T* soft_out_ptr = (T*)soft_out.data_ptr();
T* ctx_bufB_ptr = (T*)ctx_bufB.data_ptr();
layer->SetTrainingMode(training_mode);
layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
(uint8_t*)attn_output_dropout_mask.data_ptr(),
(uint8_t*)layer_output_dropout_mask.data_ptr(),
(T*)attn_layer_norm_var.data_ptr(),
(T*)attn_layer_norm_mean.data_ptr(),
(T*)layer_norm_var.data_ptr(),
(T*)layer_norm_mean.data_ptr());
layer->Forward(bsz,
input_ptr,
input_mask_ptr,
attn_qkvw_ptr,
attn_qkvb_ptr,
attn_ow_ptr,
attn_ob_ptr,
attn_nw_ptr,
attn_nb_ptr,
inter_w_ptr,
inter_b_ptr,
output_w_ptr,
output_b_ptr,
norm_w_ptr,
norm_b_ptr,
out_ptr,
inp_norm_ptr,
q_tf_ptr,
k_tf_ptr,
v_tf_ptr,
soft_out_ptr,
ctx_bufB_ptr,
attn_o_inp_ptr,
add_res_ptr,
ff1_inp_ptr,
gelu_inp_ptr,
ff2_inp_ptr);
return {output,
inp_norm,
qkv_tf,
soft_out,
ctx_bufB,
attn_o_inp,
add_res,
ff1_inp,
gelu_inp,
ff2_inp,
attn_prob_dropout_mask,
attn_output_dropout_mask,
layer_output_dropout_mask,
attn_layer_norm_var,
attn_layer_norm_mean,
layer_norm_var,
layer_norm_mean};
}
template <typename T>
std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
const torch::Tensor& grad_output,
const torch::Tensor& output,
const torch::Tensor& inp_norm,
const torch::Tensor& qkv_tf,
const torch::Tensor& soft_out,
const torch::Tensor& ctx_bufB,
const torch::Tensor& attn_o_inp,
const torch::Tensor& add_res,
const torch::Tensor& ff1_inp,
const torch::Tensor& gelu_inp,
const torch::Tensor& ff2_inp,
const torch::Tensor& attn_prob_dropout_mask,
const torch::Tensor& attn_output_dropout_mask,
const torch::Tensor& layer_output_dropout_mask,
const torch::Tensor& attn_layer_norm_var,
const torch::Tensor& attn_layer_norm_mean,
const torch::Tensor& layer_norm_var,
const torch::Tensor& layer_norm_mean,
const torch::Tensor& input,
const torch::Tensor& input_mask,
const torch::Tensor& attn_qkvw,
const torch::Tensor& attn_qkvb,
const torch::Tensor& attn_ow,
const torch::Tensor& attn_ob,
const torch::Tensor& attn_nw,
const torch::Tensor& attn_nb,
const torch::Tensor& inter_w,
const torch::Tensor& inter_b,
const torch::Tensor& output_w,
const torch::Tensor& output_b,
const torch::Tensor& norm_w,
const torch::Tensor& norm_b)
{
auto g_output = grad_output.contiguous();
CHECK_INPUT(g_output);
CHECK_INPUT(output);
CHECK_INPUT(inp_norm);
CHECK_INPUT(qkv_tf);
CHECK_INPUT(add_res);
CHECK_INPUT(soft_out);
CHECK_INPUT(ctx_bufB);
CHECK_INPUT(attn_o_inp);
CHECK_INPUT(ff1_inp);
CHECK_INPUT(gelu_inp);
CHECK_INPUT(ff2_inp);
CHECK_INPUT(input);
CHECK_INPUT(input_mask);
CHECK_INPUT(attn_qkvw);
CHECK_INPUT(attn_qkvb);
CHECK_INPUT(attn_ow);
CHECK_INPUT(attn_ob);
CHECK_INPUT(attn_nw);
CHECK_INPUT(attn_nb);
CHECK_INPUT(inter_w);
CHECK_INPUT(inter_b);
CHECK_INPUT(output_w);
CHECK_INPUT(output_b);
CHECK_INPUT(norm_w);
CHECK_INPUT(norm_b);
int bsz = g_output.size(0);
std::shared_ptr<BertTransformerLayer<T>> layer =
std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);
int seq_len = layer->GetSeqLength();
if (g_output.size(1) != seq_len) {
seq_len = g_output.size(1);
layer->SetSeqLength(seq_len);
}
auto options = torch::TensorOptions()
.dtype(g_output.options().dtype())
.layout(torch::kStrided)
.device(torch::kCUDA)
.requires_grad(true);
auto workspace = torch::empty({get_workspace_size<T>(bsz,
seq_len,
layer->GetHiddenSize(),
layer->GetIntermediateSize(),
layer->GetNumHeads(),
layer->IsTrainingMode(),
layer->GeluCheckpoint())},
options);
Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
auto grad_input = torch::empty_like(input);
auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
auto grad_attn_ow = torch::empty_like(attn_ow);
auto grad_attn_ob = torch::empty_like(attn_ob);
auto grad_attn_nw = torch::empty_like(attn_nw);
auto grad_attn_nb = torch::empty_like(attn_nb);
auto grad_inter_w = torch::empty_like(inter_w);
auto grad_inter_b = torch::empty_like(inter_b);
auto grad_output_w = torch::empty_like(output_w);
auto grad_output_b = torch::empty_like(output_b);
auto grad_norm_w = torch::empty_like(norm_w);
auto grad_norm_b = torch::empty_like(norm_b);
// inputs.
const T* grad_output_ptr = (const T*)g_output.data_ptr();
const T* input_ptr = (const T*)input.data_ptr();
const T* output_ptr = (const T*)output.data_ptr();
const T* inp_norm_ptr = (const T*)inp_norm.data_ptr();
const T* q_tf_ptr = (const T*)qkv_tf.data_ptr();
const T* add_res_ptr = (const T*)add_res.data_ptr();
const T* k_tf_ptr =
q_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0)); //(const T*)k_tf.data_ptr();
const T* v_tf_ptr =
k_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0)); //(const T*)v_tf.data_ptr();
const T* ff1_inp_ptr = (const T*)ff1_inp.data_ptr();
const T* gelu_inp_ptr = (const T*)gelu_inp.data_ptr();
const T* ff2_inp_ptr = (const T*)ff2_inp.data_ptr();
const T* ctx_bufB_ptr = (const T*)ctx_bufB.data_ptr();
const T* soft_out_ptr = (const T*)soft_out.data_ptr();
const T* attn_o_inp_ptr = (const T*)attn_o_inp.data_ptr();
const T* input_mask_ptr = (const T*)input_mask.data_ptr();
const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
const T* inter_w_ptr = (const T*)inter_w.data_ptr();
const T* inter_b_ptr = (const T*)inter_b.data_ptr();
const T* output_w_ptr = (const T*)output_w.data_ptr();
const T* norm_w_ptr = (const T*)norm_w.data_ptr();
const T* norm_b_ptr = (const T*)norm_b.data_ptr();
// outputs.
T* grad_input_ptr = (T*)grad_input.data_ptr();
T* grad_attn_qkvw_ptr = (T*)grad_attn_qkvw.data_ptr();
T* grad_attn_qkvb_ptr = (T*)grad_attn_qkvb.data_ptr();
T* grad_attn_ow_ptr = (T*)grad_attn_ow.data_ptr();
T* grad_attn_ob_ptr = (T*)grad_attn_ob.data_ptr();
T* grad_attn_nw_ptr = (T*)grad_attn_nw.data_ptr();
T* grad_attn_nb_ptr = (T*)grad_attn_nb.data_ptr();
T* grad_inter_w_ptr = (T*)grad_inter_w.data_ptr();
T* grad_inter_b_ptr = (T*)grad_inter_b.data_ptr();
T* grad_output_w_ptr = (T*)grad_output_w.data_ptr();
T* grad_output_b_ptr = (T*)grad_output_b.data_ptr();
T* grad_norm_w_ptr = (T*)grad_norm_w.data_ptr();
T* grad_norm_b_ptr = (T*)grad_norm_b.data_ptr();
layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
(uint8_t*)attn_output_dropout_mask.data_ptr(),
(uint8_t*)layer_output_dropout_mask.data_ptr(),
(T*)attn_layer_norm_var.data_ptr(),
(T*)attn_layer_norm_mean.data_ptr(),
(T*)layer_norm_var.data_ptr(),
(T*)layer_norm_mean.data_ptr());
layer->Backward(bsz,
grad_output_ptr,
input_ptr,
output_ptr,
inp_norm_ptr,
q_tf_ptr,
k_tf_ptr,
v_tf_ptr,
soft_out_ptr,
ctx_bufB_ptr,
attn_o_inp_ptr,
add_res_ptr,
ff1_inp_ptr,
gelu_inp_ptr,
ff2_inp_ptr,
input_mask_ptr,
attn_qkvw_ptr,
attn_ow_ptr,
attn_nw_ptr,
attn_nb_ptr,
inter_w_ptr,
inter_b_ptr,
output_w_ptr,
norm_w_ptr,
norm_b_ptr,
grad_input_ptr,
grad_attn_qkvw_ptr,
grad_attn_qkvb_ptr,
grad_attn_ow_ptr,
grad_attn_ob_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
grad_inter_w_ptr,
grad_inter_b_ptr,
grad_output_w_ptr,
grad_output_b_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr);
return {grad_input,
grad_attn_qkvw,
grad_attn_qkvb,
grad_attn_ow,
grad_attn_ob,
grad_attn_nw,
grad_attn_nb,
grad_inter_w,
grad_inter_b,
grad_output_w,
grad_output_b,
grad_norm_w,
grad_norm_b};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("forward_fp32",
&ds_transformer_forward<float>,
"DeepSpeed Transformer forward with fp32 (CUDA)");
m.def("forward_fp16",
&ds_transformer_forward<__half>,
"DeepSpeed Transformer forward with fp16 (CUDA)");
m.def("backward_fp32",
&ds_transformer_backward<float>,
"DeepSpeed Transformer backward with fp32 (CUDA)");
m.def("backward_fp16",
&ds_transformer_backward<__half>,
"DeepSpeed Transformer backward with fp16 (CUDA)");
m.def("create_transformer_layer_fp32",
&create_transformer_layer<float>,
"Create DeepSpeed Transformer Transformer Layer with fp32 (CUDA)");
m.def("create_transformer_layer_fp16",
&create_transformer_layer<__half>,
"Create DeepSpeed Transformer Transformer Layer with fp16 (CUDA)");
}
#include "hip/hip_runtime.h"
#include "hip/custom_hip_layers.h"
inline __device__ float gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}
inline __device__ float d_gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
float x2mul = x * x * mul_param;
float tan_h = tanhf(sqrt_param * (x + x * x2mul));
float dg1 = 0.5f * (1.0f + tan_h);
float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
float dg3 = dg2 * 3 * x2mul;
return (dg1 + dg2 + dg3);
}
/*
Fused bias add with GELU
Loads a vector of 4 elements each iteration, for stride
iterations. It was written with the intention to launch 256 thread
threadblocks, so to launch for bert-large, we would set ITERATIONS
to 4. This is currently done automatically as a heuristic, setting
the number of iterations as blocks of 1024.
For FP16, the values are loaded from memory as __half, but converted
to FP32 for the arithmetic itself, to prevent numerous overflow on
the intermediate hyperbolic tangent, since there's no intrinsic
that computes it directly.
*/
__global__ void gelu_kernel(const float* input, float* vals, int intermediate_size)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
int iterations = intermediate_size / blockDim.x / 4;
int row_stride = intermediate_size / 4;
const float4* input_cast = reinterpret_cast<const float4*>(input);
float4* vals_cast = reinterpret_cast<float4*>(vals);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 data = input_cast[row * row_stride + i * loop_stride + id];
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
vals_cast[row * row_stride + i * loop_stride + id] = data;
}
}
}
__global__ void gelu_kernel(const __half* input, __half* vals, int intermediate_size)
{
#if __CUDA_ARCH__ >= 700
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
int iterations = intermediate_size / blockDim.x / 4;
int row_stride = intermediate_size / 4;
const float2* input_cast = reinterpret_cast<const float2*>(input);
float2* vals_cast = reinterpret_cast<float2*>(vals);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
}
}
#endif
}
__global__ void fused_bias_gelu(const float* input,
const float* bias,
float* vals,
int intermediate_size)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
int iterations = intermediate_size / blockDim.x / 4;
int row_stride = intermediate_size / 4;
const float4* input_cast = reinterpret_cast<const float4*>(input);
float4* vals_cast = reinterpret_cast<float4*>(vals);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 data = input_cast[row * row_stride + i * loop_stride + id];
float4 bias_data = bias_cast[i * loop_stride + id];
data.x += bias_data.x;
data.y += bias_data.y;
data.z += bias_data.z;
data.w += bias_data.w;
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
vals_cast[row * row_stride + i * loop_stride + id] = data;
}
}
}
__global__ void fused_bias_gelu(const __half* input,
const __half* bias,
__half* vals,
int intermediate_size)
{
#if __CUDA_ARCH__ >= 700
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
int iterations = intermediate_size / blockDim.x / 4;
int row_stride = intermediate_size / 4;
const float2* input_cast = reinterpret_cast<const float2*>(input);
float2* vals_cast = reinterpret_cast<float2*>(vals);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];
float2 bias_vec = bias_cast[i * loop_stride + id];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
low_data.x += low_bias.x;
low_data.y += low_bias.y;
high_data.x += high_bias.x;
high_data.y += high_bias.y;
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
}
}
#endif
}
__global__ void d_gelu_func(float* d_output,
const float* gelu_input,
const float* bias,
int intermediate_size)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
int iterations = intermediate_size / blockDim.x / 4;
int row_stride = intermediate_size / 4;
float4* d_output_cast = reinterpret_cast<float4*>(d_output);
const float4* gelu_input_cast = reinterpret_cast<const float4*>(gelu_input);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
float4 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
float4 bias_data = bias_cast[i * loop_stride + id];
gelu_input_data.x += bias_data.x;
gelu_input_data.y += bias_data.y;
gelu_input_data.z += bias_data.z;
gelu_input_data.w += bias_data.w;
output_data.x *= d_gelu(gelu_input_data.x);
output_data.y *= d_gelu(gelu_input_data.y);
output_data.z *= d_gelu(gelu_input_data.z);
output_data.w *= d_gelu(gelu_input_data.w);
d_output_cast[row * row_stride + i * loop_stride + id] = output_data;
}
}
}
__global__ void d_gelu_func(__half* d_output,
const __half* gelu_input,
const __half* bias,
int intermediate_size)
{
#if __CUDA_ARCH__ >= 700
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
int iterations = intermediate_size / blockDim.x / 4;
int row_stride = intermediate_size / 4;
float2* d_output_cast = reinterpret_cast<float2*>(d_output);
const float2* gelu_input_cast = reinterpret_cast<const float2*>(gelu_input);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
#pragma unroll
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
float2 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
float2 bias_vec = bias_cast[i * loop_stride + id];
__half2* output_data_half = reinterpret_cast<__half2*>(&output_data);
__half2* gelu_input_data_half = reinterpret_cast<__half2*>(&gelu_input_data);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 output_half_0 = __half22float2(output_data_half[0]);
float2 output_half_1 = __half22float2(output_data_half[1]);
float2 gelu_input_half_0 = __half22float2(gelu_input_data_half[0]);
float2 gelu_input_half_1 = __half22float2(gelu_input_data_half[1]);
float2 bias_half_0 = __half22float2(bias_half[0]);
float2 bias_half_1 = __half22float2(bias_half[1]);
gelu_input_half_0.x += bias_half_0.x;
gelu_input_half_0.y += bias_half_0.y;
gelu_input_half_1.x += bias_half_1.x;
gelu_input_half_1.y += bias_half_1.y;
output_half_0.x *= d_gelu(gelu_input_half_0.x);
output_half_0.y *= d_gelu(gelu_input_half_0.y);
output_half_1.x *= d_gelu(gelu_input_half_1.x);
output_half_1.y *= d_gelu(gelu_input_half_1.y);
float2 result;
__half2* result_half2 = reinterpret_cast<__half2*>(&result);
result_half2[0] = __float22half2_rn(output_half_0);
result_half2[1] = __float22half2_rn(output_half_1);
d_output_cast[row * row_stride + i * loop_stride + id] = result;
}
}
#endif
}
template <typename T>
void launch_bias_gelu(const T* input,
const T* bias,
T* output,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
hipLaunchKernelGGL(( fused_bias_gelu), dim3(grid_dims), dim3(block_dims), 0, stream, input, bias, output, intermediate_size);
}
template <typename T>
void launch_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
hipLaunchKernelGGL(( gelu_kernel), dim3(grid_dims), dim3(block_dims), 0, stream, input, output, intermediate_size);
}
template void launch_bias_gelu<float>(const float*, const float*, float*, int, int, hipStream_t);
template void launch_bias_gelu<__half>(const __half*,
const __half*,
__half*,
int,
int,
hipStream_t);
template void launch_gelu<float>(const float*, float*, int, int, hipStream_t);
template void launch_gelu<__half>(const __half*, __half*, int, int, hipStream_t);
template <typename T>
void launch_d_gelu(T* d_output,
const T* input,
const T* bias,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
hipLaunchKernelGGL(( d_gelu_func), dim3(grid_dims), dim3(block_dims), 0, stream, d_output, input, bias, intermediate_size);
}
template void launch_d_gelu<float>(float*, const float*, const float*, int, int, hipStream_t);
template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, hipStream_t);
#include "hip/hip_runtime.h"
#include "hip/general_kernels.h"
namespace cg = cooperative_groups;
//template <typename T>
//__global__ void column_sum_reduce(const T* __restrict__ inp,
// T* __restrict__ out,
// int rows,
// int width)
//{
// __shared__ float tile[TILE_DIM][TILE_DIM + 1];
//
// cg::thread_block b = cg::this_thread_block();
// cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
//
// int idx = blockDim.x * blockIdx.x + threadIdx.x;
//
// int y_stride = width * TILE_DIM;
//
// float localSum = 0;
//
// // Loop across matrix height
// if (idx < width) {
// int offset = threadIdx.y * width + idx;
// for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
// localSum += (float)inp[offset];
// offset += y_stride;
// }
// }
//
// tile[threadIdx.x][threadIdx.y] = localSum;
//
// __syncthreads();
//
// // Sum the shared buffer.
// float sum = tile[threadIdx.y][threadIdx.x];
//
//#ifndef __STOCHASTIC_MODE__
// __syncthreads();
//#endif
//
// for (int i = 1; i < TILE_DIM; i <<= 1) sum += g.shfl_down(sum, i);
//
// if (threadIdx.x == 0) {
// int pos = blockIdx.x * TILE_DIM + threadIdx.y;
// if (pos < width) out[pos] = sum;
// }
//}
//template <typename T>
//void launch_fuse_transpose_bias_kernel(const T* inp,
// T* out,
// int rows,
// int cols,
// hipStream_t stream);
//
//template <>
//void launch_fuse_transpose_bias_kernel<float>(const float* inp,
// float* out,
// int rows,
// int cols,
// hipStream_t stream)
//{
// // assert(rows % TILE_DIM == 0);
// // assert(cols % TILE_DIM == 0);
//
// dim3 grid_dim((cols - 1) / TILE_DIM + 1);
// dim3 block_dim(TILE_DIM, TILE_DIM);
//
// hipLaunchKernelGGL(( column_sum_reduce<float>), dim3(grid_dim), dim3(block_dim), 0, stream, inp, out, rows, cols);
//}
//template <>
//void launch_fuse_transpose_bias_kernel<__half>(const __half* inp,
// __half* out,
// int rows,
// int cols,
// hipStream_t stream)
//{
// // assert(rows % TILE_DIM == 0);
// // assert(cols % TILE_DIM == 0);
//
// dim3 grid_dim((cols - 1) / TILE_DIM + 1);
// dim3 block_dim(TILE_DIM, TILE_DIM);
//
// hipLaunchKernelGGL(( column_sum_reduce<__half>), dim3(grid_dim), dim3(block_dim), 0, stream, inp, out, rows, cols);
//}
__global__ void fused_add2_kernel(const int N, float* out, const float* inp1, const float* inp2)
{
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
float4* out_4 = reinterpret_cast<float4*>(out);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 val;
float4 inp1_reg = inp1_4[j];
float4 inp2_reg = inp2_4[j];
val.x = inp1_reg.x + inp2_reg.x;
val.y = inp1_reg.y + inp2_reg.y;
val.z = inp1_reg.z + inp2_reg.z;
val.w = inp1_reg.w + inp2_reg.w;
out_4[j] = val;
}
}
__global__ void fused_add2_kernel(const int N, __half* out, const __half* inp1, const __half* inp2)
{
float2 inp1_4;
float2 inp2_4;
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
CUDA_1D_KERNEL_LOOP(j, N)
{
inp1_4 = inp1_arr[j];
inp2_4 = inp2_arr[j];
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
inp1_h_f_0.x += inp2_h_f_0.x;
inp1_h_f_0.y += inp2_h_f_0.y;
inp1_h_f_1.x += inp2_h_f_1.x;
inp1_h_f_1.y += inp2_h_f_1.y;
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[j] = val_f;
}
}
template <>
void launch_fused_add2<float>(float* out,
const float* inp1,
const float* inp2,
int batch_size,
int seq_length,
int hidden_dim,
hipStream_t& stream)
{
int total_count = batch_size * seq_length * hidden_dim / 4;
dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length);
dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
hipLaunchKernelGGL(( fused_add2_kernel), dim3(grid_dim), dim3(block_dim), 0, stream, total_count, out, inp1, inp2);
}
template <>
void launch_fused_add2<__half>(__half* out,
const __half* inp1,
const __half* inp2,
int batch_size,
int seq_length,
int hidden_dim,
hipStream_t& stream)
{
int total_count = batch_size * seq_length * hidden_dim / 4;
dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length);
dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
hipLaunchKernelGGL(( fused_add2_kernel), dim3(grid_dim), dim3(block_dim), 0, stream, total_count, out, inp1, inp2);
}
__global__ void fused_add3_kernel(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
float4* out_4 = reinterpret_cast<float4*>(out);
float4 val;
float4 inp1_reg = inp1_4[row * row_stride + id];
float4 inp2_reg = inp2_4[row * row_stride + id];
float4 inp3_reg = inp3_4[row * row_stride + id];
val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x;
val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y;
val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z;
val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w;
out_4[row * row_stride + id] = val;
}
__global__ void fused_add3_kernel(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
float2 inp1_4 = inp1_arr[row * row_stride + id];
float2 inp2_4 = inp2_arr[row * row_stride + id];
float2 inp3_4 = inp3_arr[row * row_stride + id];
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
__half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
float2 inp3_h_f_1 = __half22float2(inp3_h[1]);
inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x);
inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y);
inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x);
inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y);
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[row * row_stride + id] = val_f;
}
template <>
void launch_fused_add3<float>(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add3_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
template <>
void launch_fused_add3<__half>(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add3_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
__global__ void fused_add4_kernel(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
const float* inp4,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
const float4* inp4_4 = reinterpret_cast<const float4*>(inp4);
float4* out_4 = reinterpret_cast<float4*>(out);
float4 val;
float4 inp1_reg = inp1_4[row * row_stride + id];
float4 inp2_reg = inp2_4[row * row_stride + id];
float4 inp3_reg = inp3_4[row * row_stride + id];
float4 inp4_reg = inp4_4[row * row_stride + id];
val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x + inp4_reg.x;
val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y + inp4_reg.y;
val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z + inp4_reg.z;
val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w + inp4_reg.w;
out_4[row * row_stride + id] = val;
}
__global__ void fused_add4_kernel(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
const __half* inp4,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
const float2* inp4_arr = reinterpret_cast<const float2*>(inp4);
float2 inp1_4 = inp1_arr[row * row_stride + id];
float2 inp2_4 = inp2_arr[row * row_stride + id];
float2 inp3_4 = inp3_arr[row * row_stride + id];
float2 inp4_4 = inp4_arr[row * row_stride + id];
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
__half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
__half2* inp4_h = reinterpret_cast<__half2*>(&inp4_4);
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
float2 inp3_h_f_1 = __half22float2(inp3_h[1]);
float2 inp4_h_f_0 = __half22float2(inp4_h[0]);
float2 inp4_h_f_1 = __half22float2(inp4_h[1]);
inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x + inp4_h_f_0.x);
inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y + inp4_h_f_0.y);
inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x + inp4_h_f_1.x);
inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y + inp4_h_f_1.y);
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[row * row_stride + id] = val_f;
}
template <>
void launch_fused_add4<float>(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
const float* inp4,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add4_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
template <>
void launch_fused_add4<__half>(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
const __half* inp4,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add4_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment