OpenDAS / deepspeed / Commits / 4acf0e01

Commit 4acf0e01, authored Apr 26, 2023 by aiss

    delete hip file

Parent: 7dd68788
Changes: 83

Showing 20 changed files with 0 additions and 10331 deletions (+0 -10331)
csrc/includes/softmax_hip.h                                +0  -61
csrc/includes/strided_batch_gemm_hip.h                     +0  -196
csrc/includes/type_shim_hip.h                              +0  -121
csrc/lamb/fused_lamb_hip_kernel.hip                        +0  -475
csrc/quantization/pt_binding_hip.cpp                       +0  -78
csrc/quantization/quantizer.hip                            +0  -1039
csrc/transformer/cublas_wrappers.hip                       +0  -404
csrc/transformer/dropout_kernels.hip                       +0  -870
csrc/transformer/ds_transformer_hip.cpp                    +0  -1052
csrc/transformer/gelu_kernels.hip                          +0  -332
csrc/transformer/general_kernels.hip                       +0  -413
csrc/transformer/inference/csrc/apply_rotary_pos_emb.hip   +0  -407
csrc/transformer/inference/csrc/dequantize.hip             +0  -196
csrc/transformer/inference/csrc/gelu.hip                   +0  -685
csrc/transformer/inference/csrc/layer_norm.hip             +0  -531
csrc/transformer/inference/csrc/normalize.hip              +0  -455
csrc/transformer/inference/csrc/pt_binding_hip.cpp         +0  -1781
csrc/transformer/inference/csrc/relu.hip                   +0  -65
csrc/transformer/inference/csrc/softmax.hip                +0  -490
csrc/transformer/inference/csrc/transform.hip              +0  -680
csrc/includes/softmax_hip.h (deleted, 100644 → 0)
// !!! This is a file automatically generated by hipify!!!
#pragma once

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include "custom_hip_layers.h"

#include <fstream>

using namespace std;

template <typename T>
class Softmax {
public:
    struct Config {
        size_t batchSize;
        size_t heads;
        size_t seq_length;
        size_t prob_depth;
        float temperature;
        bool mem_alloc;
        Config(size_t batch, size_t h, size_t seq, int prob_size = 0, bool mem_alloc = false)
            : batchSize(batch),
              heads(h),
              seq_length(seq),
              prob_depth(prob_size),
              temperature(1.0),
              mem_alloc(mem_alloc)
        {
        }
    };

    Softmax(Config config) : config_(config) {}

    ~Softmax() {}

    void Forward(int bsz, T* vals, const T* attn_mask, hipStream_t& stream)
    {
        launch_attn_softmax<T>(vals, attn_mask, bsz, config_.heads, config_.seq_length, stream);
    }

    void Backward(int bsz, T* out_grad, const T* soft_out, hipStream_t stream)
    {
        launch_attn_softmax_backward_v2<T>(
            out_grad, soft_out, bsz, config_.heads, config_.seq_length, stream);
    }

    inline size_t GetProbDepth() const { return config_.prob_depth; }

    inline size_t GetBatchSize() const { return config_.batchSize; }

    inline size_t GetNumHeads() const { return config_.heads; }

    inline size_t GetSeqLength() const { return config_.seq_length; }

    inline void SetSeqLength(size_t seq_len) { config_.seq_length = seq_len; }

private:
    Config config_;
};
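For context, a minimal sketch of how the removed Softmax<T> wrapper would typically be driven from host code; the batch/head/sequence sizes, pointers, and stream below are hypothetical, and launch_attn_softmax comes from the likewise-removed custom_hip_layers.h.

// Hypothetical host-side sketch (not part of the diff) exercising the deleted Softmax<T>.
void softmax_example(__half* attn_scores, const __half* attn_mask, hipStream_t stream)
{
    // 8 sequences, 16 heads, sequence length 512 (all made-up numbers).
    Softmax<__half>::Config cfg(8, 16, 512);
    Softmax<__half> softmax(cfg);

    // Forward applies the masked attention softmax in place over attn_scores.
    softmax.Forward(/*bsz*/ 8, attn_scores, attn_mask, stream);

    // The wrapper allows resizing between calls, e.g. for a shorter batch.
    softmax.SetSeqLength(256);
}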
csrc/includes/strided_batch_gemm_hip.h (deleted, 100644 → 0)
// !!! This is a file automatically generated by hipify!!!
#pragma once

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include "context_hip.h"

template <typename T>
class StridedBatchGemm {
public:
    struct Config {
        int batch_size;
        int m;
        int n;
        int k;
        float alpha;
        float beta;
        rocblas_operation op_A;
        rocblas_operation op_B;
        std::array<int, 3> gemm_algos;

        Config(int batch,
               int mm,
               int nn,
               int kk,
               float param_alpha,
               float param_beta,
               rocblas_operation opA,
               rocblas_operation opB,
               const std::array<int, 3>& algos)
            : batch_size(batch),
              m(mm),
              n(nn),
              k(kk),
              alpha(param_alpha),
              beta(param_beta),
              op_A(opA),
              op_B(opB),
              gemm_algos(algos)
        {
        }
        void SetConfig(int mm, int nn, int kk)
        {
            m = mm;
            n = nn;
            k = kk;
        }
    };

    StridedBatchGemm(const Config& config) : _config(config) {}

    virtual ~StridedBatchGemm() {}

    void Forward(int bsz, T* output, const T* _buffer_a, const T* _buffer_b, rocblas_handle handle)
    {
        int stride_a = _config.m * _config.k;
        int stride_b = _config.n * _config.k;
        int stride_c = _config.m * _config.n;

        cublas_strided_batched_gemm(handle,
                                    _config.m,
                                    _config.n,
                                    _config.k,
                                    &_config.alpha,
                                    &_config.beta,
                                    _buffer_a,
                                    _buffer_b,
                                    output,
                                    _config.op_A,
                                    _config.op_B,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    bsz,
#ifdef __HIP_PLATFORM_HCC__
                                    rocblas_gemm_algo(_config.gemm_algos[0]));
#else
                                    cublasGemmAlgo_t(_config.gemm_algos[0]));
#endif
    }

    void ForwardPlusSave(T* output, const T* _buffer_a, const T* _buffer_b, rocblas_handle handle)
    {
        int stride_a = _config.m * _config.k;
        int stride_b = _config.n * _config.k;
        int stride_c = _config.m * _config.n;

        cublas_strided_batched_gemm(handle,
                                    _config.m,
                                    _config.n,
                                    _config.k,
                                    &_config.alpha,
                                    &_config.beta,
                                    _buffer_a,
                                    _buffer_b,
                                    output,
                                    _config.op_A,
                                    _config.op_B,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    _config.batch_size,
#ifdef __HIP_PLATFORM_HCC__
                                    rocblas_gemm_algo(_config.gemm_algos[0]));
#else
                                    cublasGemmAlgo_t(_config.gemm_algos[0]));
#endif

        k_buf = _buffer_a;
        q_buf = _buffer_b;
    }

    void Backward(int bsz,
                  const T* d_output,
                  const T* _buffer_a,
                  const T* _buffer_b,
                  rocblas_handle handle,
                  T* inpGradA = nullptr,
                  T* inpGradB = nullptr)
    {
        int mb = (_config.op_A == rocblas_operation_transpose ? _config.k : _config.m);
        int kb = (_config.op_A == rocblas_operation_transpose ? _config.m : _config.k);

        int stride_a = mb * _config.n;
        int stride_b = _config.n * kb;
        int stride_c = _config.m * _config.k;

        // B need to transpose.
        rocblas_operation op_b = (_config.op_B == rocblas_operation_transpose
                                      ? rocblas_operation_none
                                      : rocblas_operation_transpose);

        // Calculate d_A.
        cublas_strided_batched_gemm(handle,
                                    mb,
                                    kb,
                                    _config.n,
                                    &_config.alpha,
                                    &_config.beta,
                                    (_config.op_A == rocblas_operation_transpose ? _buffer_b : d_output),
                                    (_config.op_A == rocblas_operation_transpose ? d_output : _buffer_b),
                                    inpGradA,
                                    rocblas_operation_none,
                                    op_b,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    bsz,
#ifdef __HIP_PLATFORM_HCC__
                                    rocblas_gemm_algo(_config.gemm_algos[1]));
#else
                                    cublasGemmAlgo_t(_config.gemm_algos[1]));
#endif

        // A need to transpose.
        rocblas_operation op_a = (_config.op_A == rocblas_operation_transpose
                                      ? rocblas_operation_none
                                      : rocblas_operation_transpose);

        stride_a = _config.m * _config.k;
        stride_b = _config.m * _config.n;
        stride_c = _config.n * _config.k;

        // Calculate d_B.
        cublas_strided_batched_gemm(handle,
                                    _config.k,
                                    _config.n,
                                    _config.m,
                                    &_config.alpha,
                                    &_config.beta,
                                    _buffer_a,
                                    d_output,
                                    inpGradB,
                                    op_a,
                                    rocblas_operation_none,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    bsz,
#ifdef __HIP_PLATFORM_HCC__
                                    rocblas_gemm_algo(_config.gemm_algos[2]));
#else
                                    cublasGemmAlgo_t(_config.gemm_algos[2]));
#endif
    }

    inline int GetN() const { return _config.k; }

    inline const T* GetBufferA() const { return k_buf; }

    inline const T* GetBufferB() const { return q_buf; }

    inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }

private:
    Config _config;
    const T* q_buf;
    const T* k_buf;
};
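For orientation, a hedged sketch of how the deleted StridedBatchGemm<T> helper could be configured for an attention-style batched product; every dimension, the algorithm indices, and the buffers below are assumptions, not values taken from this repository.

// Hypothetical sketch (not part of the diff): configuring the deleted StridedBatchGemm<float>.
void strided_gemm_example(rocblas_handle handle, const float* k, const float* q, float* scores)
{
    // batch = batch_size * heads; m x n output per batch element, contracted over k.
    StridedBatchGemm<float>::Config cfg(/*batch*/ 128,
                                        /*m*/ 512,
                                        /*n*/ 512,
                                        /*k*/ 64,
                                        /*alpha*/ 0.125f,
                                        /*beta*/ 0.0f,
                                        rocblas_operation_transpose,   // op_A
                                        rocblas_operation_none,        // op_B
                                        {0, 0, 0});                    // placeholder gemm algo ids
    StridedBatchGemm<float> gemm(cfg);
    gemm.Forward(/*bsz*/ 128, scores, k, q, handle);
}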
csrc/includes/type_shim_hip.h (deleted, 100644 → 0)
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */
#include <ATen/ATen.h>
// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(T* x,
                                                     T val,
                                                     int lanes = 1,
                                                     bool share_result = false)
// lanes is intended to be <= 32.
{
    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    int blockSize = blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.

    if (blockSize >= 64) {
        x[tid] = val;
        __syncthreads();
    }

#pragma unroll
    for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
        if (tid < i) x[tid] = x[tid] + x[tid + i];
        __syncthreads();
    }

    T final;

    if (tid < 32) {
        if (blockSize >= 64)
            final = x[tid] + x[tid + 32];
        else
            final = val;
        // __SYNCWARP();

#pragma unroll
        for (int i = 16; i >= lanes; i >>= 1)
            final = final + __shfl_down_sync(0xffffffff, final, i);
    }

    if (share_result) {
        if (tid < lanes) x[tid] = final;  // EpilogueOp
        // Make sure the smem result is visible to all warps.
        __syncthreads();
    }

    return final;
}
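A short, hedged example of how the DISPATCH_FLOAT_AND_HALF macro defined above is meant to be used: it switches on a runtime at::ScalarType, defines the alias scalar_t_<LEVEL> for the matching C++ type, and runs the body once with that alias in scope. The launcher named below is a made-up placeholder.

// Hypothetical usage sketch (not part of the diff) for DISPATCH_FLOAT_AND_HALF.
// my_kernel_launcher is a made-up templated function; only the macro mechanics are real.
void dispatch_example(at::Tensor& t)
{
    DISPATCH_FLOAT_AND_HALF(t.scalar_type(), 0, "my_kernel",
                            my_kernel_launcher<scalar_t_0>(t.data_ptr<scalar_t_0>(), t.numel()));
}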
csrc/lamb/fused_lamb_hip_kernel.hip (deleted, 100644 → 0)
// !!! This is a file automatically generated by hipify!!!
/* Copyright 2019 The Microsoft DeepSpeed Team */
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/ATen.h"
#include "ATen/TensorUtils.h"
#include "ATen/hip/HIPContext.h"
#include "ATen/hip/detail/IndexUtils.cuh"
//#include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <iostream>
//#include <helper_functions.h>
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
#include <hip/hip_cooperative_groups.h>
#else
#include <cooperative_groups.h>
#endif
#include <hip/hip_runtime_api.h>
#include <stdio.h>
namespace cg = cooperative_groups;
// Utility class used to avoid linker errors with extern
// unsized shared memory arrays with templated type
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
template <typename T>
struct SharedMemory {
// Ensure that we won't compile any un-specialized types
__device__ inline operator T*()
{
#ifndef _WIN32
extern __device__ void error(void);
error();
#endif
return NULL;
}
};
template <>
struct SharedMemory<float> {
__device__ inline operator float*()
{
HIP_DYNAMIC_SHARED( float, s_float)
return s_float;
}
};
template <>
struct SharedMemory<double> {
__device__ inline operator double*()
{
HIP_DYNAMIC_SHARED( double, s_double)
return s_double;
}
};
} // namespace
#include "type_shim_hip.h"
typedef enum {
ADAM_MODE_0 = 0, // eps under square root
ADAM_MODE_1 = 1 // eps outside square root
} adamMode_t;
// s_a and s_b are in shared memory
// g_a and g_b are in shared memory
template <typename T, int blockSize>
__device__ void reduce_block_in_shared_memory(T* s_a, T* s_b, T* g_a, T* g_b)
{
// Handle to thread block group
cg::thread_block cta = cg::this_thread_block();
// perform block reduction in shared memory,
unsigned int tid = cta.thread_rank();
T a_sum = s_a[tid];
T b_sum = s_b[tid];
cg::sync(cta);
// do reduction in shared mem
if ((blockSize >= 512) && (tid < 256)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 256];
s_b[tid] = b_sum = b_sum + s_b[tid + 256];
}
cg::sync(cta);
if ((blockSize >= 256) && (tid < 128)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 128];
s_b[tid] = b_sum = b_sum + s_b[tid + 128];
}
cg::sync(cta);
if ((blockSize >= 128) && (tid < 64)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 64];
s_b[tid] = b_sum = b_sum + s_b[tid + 64];
}
cg::sync(cta);
#if (__CUDA_ARCH__ >= 300)
if (tid < 32) {
cg::coalesced_group active = cg::coalesced_threads();
// Fetch final intermediate sum from 2nd warp
if (blockSize >= 64) {
a_sum = a_sum + s_a[tid + 32];
b_sum = b_sum + s_b[tid + 32];
}
// Reduce final warp using shuffle
for (int offset = warpSize / 2; offset > 0; offset /= 2) {
a_sum += active.shfl_down(a_sum, offset);
b_sum += active.shfl_down(b_sum, offset);
}
}
#else
if ((blockSize >= 64) && (tid < 32)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 32];
s_b[tid] = b_sum = b_sum + s_b[tid + 32];
}
cg::sync(cta);
if ((blockSize >= 32) && (tid < 16)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 16];
s_b[tid] = b_sum = b_sum + s_b[tid + 16];
}
cg::sync(cta);
if ((blockSize >= 16) && (tid < 8)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 8];
s_b[tid] = b_sum = b_sum + s_b[tid + 8];
}
cg::sync(cta);
if ((blockSize >= 8) && (tid < 4)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 4];
s_b[tid] = b_sum = b_sum + s_b[tid + 4];
}
cg::sync(cta);
if ((blockSize >= 4) && (tid < 2)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 2];
s_b[tid] = b_sum = b_sum + s_b[tid + 2];
}
cg::sync(cta);
if ((blockSize >= 2) && (tid < 1)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 1];
s_b[tid] = b_sum = b_sum + s_b[tid + 1];
}
cg::sync(cta);
#endif
// write result for this block to global mem
if (tid == 0) {
g_a[blockIdx.x] = (T)a_sum;
g_b[blockIdx.x] = (T)b_sum;
}
}
template <typename T, int blockSize>
__device__ void reduce_two_vectors_in_register(T a, T b, T* g_a, T* g_b)
{
const int threadIdInBlock = cg::this_thread_block().thread_rank();
T* s_a = SharedMemory<T>();
T* s_b = SharedMemory<T>() + cg::this_thread_block().size();
s_a[threadIdInBlock] = a;
s_b[threadIdInBlock] = b;
reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part1(
T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T* __restrict__ g,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay,
T* __restrict__ w_l2_i,
T* __restrict__ u_l2_i)
{
// Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = cg::this_thread_block().thread_rank();
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;
T reg_w = 0;
T reg_u = 0;
for (int j = i; j < tsize; j += totThreads) {
T scaled_grad = g[j] / grad_scale;
T pj = p[j];
m[j] = b1 * m[j] + (1 - b1) * scaled_grad;
v[j] = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(v[j] + eps);
else // Mode 1
denom = sqrtf(v[j]) + eps;
T update = (m[j] / denom) + (decay * p[j]);
reg_u += update * update;
reg_w += pj * pj;
}
reduce_two_vectors_in_register<T, blockSize>(reg_w, reg_u, w_l2_i, u_l2_i);
}
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part2(const size_t tsize, T* __restrict__ g_a, T* __restrict__ g_b)
{
T* s_a = SharedMemory<T>();
T* s_b = SharedMemory<T>() + cg::this_thread_block().size();
const int threadIdInBlock = cg::this_thread_block().thread_rank();
s_a[threadIdInBlock] = g_a[threadIdInBlock];
s_b[threadIdInBlock] = g_b[threadIdInBlock];
if (threadIdInBlock >= tsize) {
s_a[threadIdInBlock] = 0.0;
s_b[threadIdInBlock] = 0.0;
}
reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}
template <typename T, typename GRAD_T>
__global__ void lamb_cuda_kernel_part3(
T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T* __restrict__ g,
const float b1,
const float b2,
const float max_coeff,
const float min_coeff,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay,
T* __restrict__ w_l2_i,
T* __restrict__ u_l2_i,
T* __restrict__ lamb_coeff_val)
{
// Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = cg::this_thread_block().thread_rank();
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;
T reg_w = sqrtf(w_l2_i[0]);
T reg_u = sqrtf(u_l2_i[0]);
float lamb_coeff = 1.0;
if (reg_w != 0 && reg_u != 0) {
lamb_coeff = reg_w / reg_u;
if (lamb_coeff > max_coeff) { lamb_coeff = max_coeff; }
if (lamb_coeff < min_coeff) { lamb_coeff = min_coeff; }
}
if (blockId == 0 && threadIdInBlock == 0) {
lamb_coeff_val[0] = lamb_coeff;
// printf("Cuda Lamb Coeff is %.6f \n",lamb_coeff);
}
for (int j = i; j < tsize; j += totThreads) {
T pj = (float)p[j];
T mj = m[j];
T vj = v[j];
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(vj + eps);
else // Mode 1
denom = sqrtf(vj) + eps;
T update = (mj / denom) + (decay * pj);
pj = pj - (step_size * lamb_coeff * update);
p[j] = pj;
if (p_copy != NULL) p_copy[j] = (GRAD_T)pj;
}
}
void fused_lamb_cuda(at::Tensor& p,
at::Tensor& p_copy,
at::Tensor& m,
at::Tensor& v,
at::Tensor& g,
float lr,
float beta1,
float beta2,
float max_coeff,
float min_coeff,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay,
at::Tensor& w_l2_i,
at::Tensor& u_l2_i,
at::Tensor& lamb_coeff)
{
// using namespace at;
// Get tensor size
int tsize = p.numel();
// Determine #threads and #blocks
const int threadsPerBlock = 512;
int num_blocks = (tsize + threadsPerBlock - 1) / threadsPerBlock;
if (num_blocks > 512) num_blocks = 512;
int smemsize = 0;
if (p.type().scalarType() == at::ScalarType::Double)
smemsize = 2 * threadsPerBlock * sizeof(double);
else
smemsize = 2 * threadsPerBlock * sizeof(float);
const dim3 blocks(num_blocks);
const dim3 threads(threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p),
"parameter tensor is too large to be indexed with int32");
// Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - ::pow(beta1, step);
const float bias_correction2 = 1 - ::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
} else {
step_size = lr;
}
hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
if (g.type().scalarType() == at::ScalarType::Half) {
// all other values should be fp32 for half gradients
AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float,
"expected parameter to be of float type");
// dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
g.scalar_type(), "lamb_cuda_kernel", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
hipLaunchKernelGGL(( lamb_cuda_kernel_part1<accscalar_t, scalar_t, threadsPerBlock>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<accscalar_t>(),
p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
m.data<accscalar_t>(),
v.data<accscalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<accscalar_t>(),
u_l2_i.data<accscalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part2<accscalar_t, scalar_t, threadsPerBlock>)
, dim3(1), dim3(threadsPerBlock), smemsize, stream,
num_blocks, w_l2_i.data<accscalar_t>(), u_l2_i.data<accscalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part3<accscalar_t, scalar_t>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<accscalar_t>(),
p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
m.data<accscalar_t>(),
v.data<accscalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
max_coeff,
min_coeff,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<accscalar_t>(),
u_l2_i.data<accscalar_t>(),
lamb_coeff.data<accscalar_t>());
}));
} else {
using namespace at;
AT_DISPATCH_FLOATING_TYPES(
g.scalar_type(), "lamb_cuda_kernel", ([&] {
hipLaunchKernelGGL(( lamb_cuda_kernel_part1<scalar_t, scalar_t, threadsPerBlock>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<scalar_t>(),
NULL, // don't output p_copy for fp32, it's wasted write
m.data<scalar_t>(),
v.data<scalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<scalar_t>(),
u_l2_i.data<scalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part2<scalar_t, scalar_t, threadsPerBlock>)
, dim3(1), dim3(threadsPerBlock), smemsize, stream,
num_blocks, w_l2_i.data<scalar_t>(), u_l2_i.data<scalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part3<scalar_t, scalar_t>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<scalar_t>(),
NULL, // don't output p_copy for fp32, it's wasted write
m.data<scalar_t>(),
v.data<scalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
max_coeff,
min_coeff,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<scalar_t>(),
u_l2_i.data<scalar_t>(),
lamb_coeff.data<scalar_t>());
}));
}
C10_HIP_CHECK(hipGetLastError());
}
// template __device__ void reduce_two_vectors_in_register<float,512>(float a, float b, float* g_a,
// float* g_b, cg::grid_group &cgg);
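For readers skimming the diff: the three kernels above jointly compute a standard LAMB step. Part 1 updates the moments and accumulates the squared parameter and update norms, part 2 finishes those block reductions, and part 3 applies the clipped trust ratio. A hedged summary in the notation of the kernel arguments:

% Hedged summary (not part of the diff) of the update computed by lamb_cuda_kernel_part1..3.
m_j \leftarrow b_1 m_j + (1 - b_1)\, g_j, \qquad
v_j \leftarrow b_2 v_j + (1 - b_2)\, g_j^2
u_j = \frac{m_j}{\mathrm{denom}} + \mathrm{decay}\cdot p_j, \qquad
\mathrm{denom} = \begin{cases} \sqrt{v_j + \mathrm{eps}} & \text{ADAM\_MODE\_0} \\ \sqrt{v_j} + \mathrm{eps} & \text{ADAM\_MODE\_1} \end{cases}
r = \mathrm{clip}\!\left(\frac{\lVert p \rVert_2}{\lVert u \rVert_2},\ \mathrm{min\_coeff},\ \mathrm{max\_coeff}\right), \qquad
p_j \leftarrow p_j - \mathrm{step\_size}\cdot r \cdot u_j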
csrc/quantization/pt_binding_hip.cpp (deleted, 100644 → 0)
// !!! This is a file automatically generated by hipify!!!
#include <ATen/hip/HIPContext.h>
#include <torch/extension.h>
#include <vector>
#include "custom_hip_layers.h"
template <typename T>
at::Tensor ds_quantize(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if ((((size / groups) - 1) / 4096 + 1) <= MAX_REG) {
        launch_quantize_kernel((T*)vals.data_ptr(),
                               size,
                               groups,
                               bits,
                               at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
    }
    return vals;
}

template <typename T>
at::Tensor ds_sr_quantize(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if (((size / groups) / 4 / 1024) <= 256) {
        launch_sr_quantize_kernel((T*)vals.data_ptr(),
                                  size,
                                  groups,
                                  bits,
                                  at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
    }
    return vals;
}

template <typename T>
at::Tensor ds_quantize_asym(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if ((((size / groups) - 1) / 4096 + 1) <= MAX_REG) {
        launch_quantize_kernel_asym((T*)vals.data_ptr(),
                                    size,
                                    groups,
                                    bits,
                                    at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
    }
    return vals;
}

template <typename T>
at::Tensor ds_sr_quantize_asym(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if (((size / groups) / 4 / 1024) <= 256) {
        launch_sr_quantize_kernel_asym((T*)vals.data_ptr(),
                                       size,
                                       groups,
                                       bits,
                                       at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
    }
    return vals;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("ds_quantize_fp32", &ds_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_quantize_fp16", &ds_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
    m.def("ds_sr_quantize_fp32", &ds_sr_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_sr_quantize_fp16", &ds_sr_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
    m.def("ds_quantize_asym_fp32", &ds_quantize_asym<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_quantize_asym_fp16", &ds_quantize_asym<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
    m.def("ds_sr_quantize_asym_fp32", &ds_sr_quantize_asym<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_sr_quantize_asym_fp16", &ds_sr_quantize_asym<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
}
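One behaviour of these bindings that is easy to miss: each wrapper only launches its kernel when the per-group work fits the kernel's per-thread register budget (ds_quantize checks ((size/groups - 1)/4096 + 1) <= MAX_REG, since a 1024-thread block loads four elements per thread per pass); if the check fails, the tensor is returned unquantized. A hedged back-of-the-envelope sketch with made-up sizes:

// Hypothetical sketch (not part of the diff) of the capacity check used by ds_quantize above.
#include <cstdio>

int main()
{
    const int kMaxRegAssumed = 24;       // made-up; the real MAX_REG comes from custom_hip_layers.h
    const int size = 8 * 16 * 512 * 64;  // made-up tensor element count
    const int groups = 128;              // made-up quantization group count
    const int per_group = size / groups;
    const int regs_needed = (per_group - 1) / 4096 + 1;  // one 4096-element pass per register slot
    std::printf("%d register slots per thread -> kernel %s launched\n",
                regs_needed,
                regs_needed <= kMaxRegAssumed ? "is" : "is NOT");
    return 0;
}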
csrc/quantization/quantizer.hip (deleted, 100644 → 0)
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include <math.h>
#include "custom_hip_layers.h"
namespace cg = cooperative_groups;
__global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
float2* vals_cast = reinterpret_cast<float2*>(vals);
float2 data[MAX_REG];
int group_id = blockIdx.x;
{
int group_index = id;
int reg_count = 0;
int offset = group_id * group_size;
float max = -10000.0;
while (group_index < group_size && reg_count < MAX_REG) {
data[reg_count] = vals_cast[offset + group_index];
__half* data_h = reinterpret_cast<__half*>(&data[reg_count]);
if (abs((float)data_h[0]) > max) max = abs((float)data_h[0]);
if (abs((float)data_h[1]) > max) max = abs((float)data_h[1]);
if (abs((float)data_h[2]) > max) max = abs((float)data_h[2]);
if (abs((float)data_h[3]) > max) max = abs((float)data_h[3]);
group_index += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
__shared__ float partialMax[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
b.sync();
if (lane < warp_num) max = partialMax[lane];
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
max = g.shfl(max, 0);
float q_scale = (1 << num_bits) / (2 * max + 1e-5);
float q_scale_inv = 1 / q_scale;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + id;
if (group_index < group_size) {
__half2* data_h = reinterpret_cast<__half2*>(&data[i]);
float2 q_data[2];
q_data[0] = __half22float2(data_h[0]);
q_data[1] = __half22float2(data_h[1]);
float2 q_data_int[2];
q_data_int[0].x = roundf(q_data[0].x * q_scale);
q_data_int[0].y = roundf(q_data[0].y * q_scale);
q_data_int[1].x = roundf(q_data[1].x * q_scale);
q_data_int[1].y = roundf(q_data[1].y * q_scale);
q_data_int[0].x *= q_scale_inv;
q_data_int[0].y *= q_scale_inv;
q_data_int[1].x *= q_scale_inv;
q_data_int[1].y *= q_scale_inv;
data_h[0] = __float22half2_rn(q_data_int[0]);
data_h[1] = __float22half2_rn(q_data_int[1]);
vals_cast[offset + group_index] = data[i];
}
}
}
#endif
}
__global__ void quantize_kernel(float* vals, int group_size, int num_bits)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
float4* vals_cast = reinterpret_cast<float4*>(vals);
float4 data[MAX_REG];
int bid = blockIdx.x;
int group_index = bid * group_size + id;
int reg_count = 0;
float max = -10000.0;
while (id < group_size && reg_count < MAX_REG) {
float4 data_reg = vals_cast[group_index];
data[reg_count] = data_reg;
if (abs(data_reg.x) > max) max = abs(data_reg.x);
if (abs(data_reg.y) > max) max = abs(data_reg.y);
if (abs(data_reg.z) > max) max = abs(data_reg.z);
if (abs(data_reg.w) > max) max = abs(data_reg.w);
group_index += blockDim.x;
id += blockDim.x;
reg_count++;
}
id = threadIdx.x;
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
__shared__ float partialMax[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
b.sync();
if (lane < warp_num) max = partialMax[lane];
b.sync();
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
max = g.shfl(max, 0);
float q_scale = (1 << num_bits) / (2 * max + 1e-5);
float q_scale_inv = 1 / q_scale;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + id;
if (group_index < group_size) {
float4 q_data;
q_data = data[i];
float4 q_data_int;
q_data_int.x = roundf(q_data.x * q_scale);
q_data_int.y = roundf(q_data.y * q_scale);
q_data_int.w = roundf(q_data.w * q_scale);
q_data_int.z = roundf(q_data.z * q_scale);
q_data.x = q_data_int.x * q_scale_inv;
q_data.y = q_data_int.y * q_scale_inv;
q_data.w = q_data_int.w * q_scale_inv;
q_data.z = q_data_int.z * q_scale_inv;
vals_cast[group_index + bid * group_size] = q_data;
}
}
}
template <typename T>
void launch_quantize_kernel(T* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream)
{
dim3 grid_dim(group_num);
dim3 block_dim(1024);
hipLaunchKernelGGL(( quantize_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, (total_count / group_num) / 4, num_bits);
}
template void launch_quantize_kernel(float* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
template void launch_quantize_kernel(__half* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
__global__ void sr_quantize_kernel(__half* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
float2* vals_cast = reinterpret_cast<float2*>(vals);
__half2 data_low[128];
__half2 data_high[128];
int bid = blockIdx.x;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
unsigned int tid = threadIdx.x;
int reg_count = 0;
int offset = bid * token_size;
int group_index = bid * token_size + tid;
int total_count = token_size * token_num;
if (group_index < total_count) {
// float min = 10000.0;
float max = -10000.0;
while (tid < token_size) {
float2 data = vals_cast[offset + tid];
__half2* data_h = reinterpret_cast<__half2*>(&data);
data_low[reg_count] = data_h[0];
data_high[reg_count] = data_h[1];
float2 data_f[2];
data_f[0] = __half22float2(data_h[0]);
data_f[1] = __half22float2(data_h[1]);
if (abs((float)data_f[0].x) > max) max = abs((float)data_f[0].x);
if (abs((float)data_f[0].y) > max) max = abs((float)data_f[0].y);
if (abs((float)data_f[1].x) > max) max = abs((float)data_f[1].x);
if (abs((float)data_f[1].y) > max) max = abs((float)data_f[1].y);
tid += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
__shared__ float partialMax[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
b.sync();
if (lane < warp_num) max = partialMax[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
max = g.shfl(max, 0);
float q_scale_val = (float)(1 << num_bits) / (max * 2 + 1e-5);
float high_q = (float)((1 << (num_bits - 1)) - 1);
float low_q = (float)(-((1 << (num_bits - 1))));
for (int i = 0; i < reg_count; i++) {
int token_index = i * blockDim.x + threadIdx.x;
if (token_index < token_size) {
float2 data_f[2];
data_f[0] = __half22float2(data_low[i]);
data_f[1] = __half22float2(data_high[i]);
float2 q_data_int[2];
q_data_int[0].x = (float)((int)(data_f[0].x * q_scale_val));
q_data_int[0].y = (float)((int)(data_f[0].y * q_scale_val));
q_data_int[1].x = (float)((int)(data_f[1].x * q_scale_val));
q_data_int[1].y = (float)((int)(data_f[1].y * q_scale_val));
// Stochastic rounding
float4 rand = hiprand_uniform4(&state);
float q_error[4];
q_error[0] = abs(data_f[0].x - (q_data_int[0].x / q_scale_val)) * q_scale_val;
q_error[1] = abs(data_f[0].y - (q_data_int[0].y / q_scale_val)) * q_scale_val;
q_error[2] = abs(data_f[1].x - (q_data_int[1].x / q_scale_val)) * q_scale_val;
q_error[3] = abs(data_f[1].y - (q_data_int[1].y / q_scale_val)) * q_scale_val;
q_data_int[0].x =
(rand.x < q_error[0] && q_data_int[0].x > low_q && q_data_int[0].x < high_q)
? (q_data_int[0].x + (data_f[0].x > 0 ? 1 : -1))
: q_data_int[0].x;
q_data_int[0].y =
(rand.y < q_error[1] && q_data_int[0].y > low_q && q_data_int[0].y < high_q)
? (q_data_int[0].y + (data_f[0].y > 0 ? 1 : -1))
: q_data_int[0].y;
q_data_int[1].x =
(rand.w < q_error[2] && q_data_int[1].x > low_q && q_data_int[1].x < high_q)
? (q_data_int[1].x + (data_f[1].x > 0 ? 1 : -1))
: q_data_int[1].x;
q_data_int[1].y =
(rand.z < q_error[3] && q_data_int[1].y > low_q && q_data_int[1].y < high_q)
? (q_data_int[1].y + (data_f[1].y > 0 ? 1 : -1))
: q_data_int[1].y;
data_f[0].x = q_data_int[0].x / q_scale_val;
data_f[0].y = q_data_int[0].y / q_scale_val;
data_f[1].x = q_data_int[1].x / q_scale_val;
data_f[1].y = q_data_int[1].y / q_scale_val;
float2 result;
__half2* result_h = reinterpret_cast<__half2*>(&result);
result_h[0] = __float22half2_rn(data_f[0]);
result_h[1] = __float22half2_rn(data_f[1]);
vals_cast[offset + token_index] = result;
}
}
}
#endif
}
__global__ void sr_quantize_kernel(float* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
int idx = blockIdx.x * blockDim.x + id;
float4* vals_cast = reinterpret_cast<float4*>(vals);
float4 data[128];
int bid = blockIdx.x;
int tid = threadIdx.x;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
int group_index = bid * token_size + threadIdx.x;
int reg_count = 0;
int total_count = token_size * token_num;
if (group_index < total_count) {
// float min = 10000.0;
float max = -10000.0;
while (tid < token_size) {
data[reg_count] = vals_cast[group_index];
if (abs(data[reg_count].x) > max) max = abs(data[reg_count].x);
if (abs(data[reg_count].y) > max) max = abs(data[reg_count].y);
if (abs(data[reg_count].z) > max) max = abs(data[reg_count].z);
if (abs(data[reg_count].w) > max) max = abs(data[reg_count].w);
group_index += blockDim.x;
tid += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
__shared__ float partialMax[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
b.sync();
if (lane < warp_num) max = partialMax[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
max = g.shfl(max, 0);
float q_scale_val = (float)(1 << num_bits) / (max * 2 + 1e-5);
float high_q = (float)((1 << (num_bits - 1)) - 1);
float low_q = (float)(-((1 << (num_bits - 1))));
int offset = (bid)*token_size;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + threadIdx.x;
if (group_index < token_size) {
float4 q_data = data[i];
float4 q_data_int;
q_data_int.x = (float)((int)(q_data.x * q_scale_val));
q_data_int.y = (float)((int)(q_data.y * q_scale_val));
q_data_int.w = (float)((int)(q_data.w * q_scale_val));
q_data_int.z = (float)((int)(q_data.z * q_scale_val));
// Stochastic rounding
float4 rand = hiprand_uniform4(&state);
float q_error[4];
q_error[0] = abs(q_data.x - (q_data_int.x / q_scale_val)) * q_scale_val;
q_error[1] = abs(q_data.y - (q_data_int.y / q_scale_val)) * q_scale_val;
q_error[2] = abs(q_data.w - (q_data_int.w / q_scale_val)) * q_scale_val;
q_error[3] = abs(q_data.z - (q_data_int.z / q_scale_val)) * q_scale_val;
q_data_int.x =
(rand.x < q_error[0] && q_data_int.x > low_q && q_data_int.x < high_q)
? (q_data_int.x + (q_data.x > 0 ? 1 : -1))
: q_data_int.x;
q_data_int.y =
(rand.y < q_error[1] && q_data_int.y > low_q && q_data_int.y < high_q)
? (q_data_int.y + (q_data.y > 0 ? 1 : -1))
: q_data_int.y;
q_data_int.w =
(rand.w < q_error[2] && q_data_int.w > low_q && q_data_int.w < high_q)
? (q_data_int.w + (q_data.w > 0 ? 1 : -1))
: q_data_int.w;
q_data_int.z =
(rand.z < q_error[3] && q_data_int.z > low_q && q_data_int.z < high_q)
? (q_data_int.z + (q_data.z > 0 ? 1 : -1))
: q_data_int.z;
q_data_int.x /= q_scale_val;
q_data_int.y /= q_scale_val;
q_data_int.w /= q_scale_val;
q_data_int.z /= q_scale_val;
vals_cast[group_index + offset] = q_data_int;
}
}
}
}
template <typename T>
void launch_sr_quantize_kernel(T* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream)
{
dim3 block_dim(1024);
dim3 grid_dim(group_num);
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
hipLaunchKernelGGL(( sr_quantize_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, (total_count / group_num) / 4, group_num, num_bits, seed);
}
template void launch_sr_quantize_kernel(float* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
template void launch_sr_quantize_kernel(__half* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
__global__ void quantize_kernel_asym(__half* vals, int group_size, int num_bits)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
float2* vals_cast = reinterpret_cast<float2*>(vals);
float2 data[MAX_REG];
int group_id = blockIdx.x;
{
int group_index = id;
int reg_count = 0;
int offset = group_id * group_size;
float max = -10000.0;
float min = 10000.0;
while (group_index < group_size && reg_count < MAX_REG) {
data[reg_count] = vals_cast[offset + group_index];
__half* data_h = reinterpret_cast<__half*>(&data[reg_count]);
if (((float)data_h[0]) > max) max = (float)data_h[0];
if (((float)data_h[1]) > max) max = (float)data_h[1];
if (((float)data_h[2]) > max) max = (float)data_h[2];
if (((float)data_h[3]) > max) max = (float)data_h[3];
if (((float)data_h[0]) < min) min = (float)data_h[0];
if (((float)data_h[1]) < min) min = (float)data_h[1];
if (((float)data_h[2]) < min) min = (float)data_h[2];
if (((float)data_h[3]) < min) min = (float)data_h[3];
group_index += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(min, i);
if (min > temp) min = temp;
}
__shared__ float partialMax[WARP_SIZE];
__shared__ float partialMin[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
if (lane == 0) partialMin[gid] = min;
b.sync();
if (lane < warp_num) max = partialMax[lane];
if (lane < warp_num) min = partialMin[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(min, i);
if (min > temp) min = temp;
}
max = g.shfl(max, 0);
min = g.shfl(min, 0);
float q_scale = ((max - min) + 1e-5) / (float)(1 << num_bits);
float q_scale_inv = 1 / q_scale;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + id;
if (group_index < group_size) {
__half2* data_h = reinterpret_cast<__half2*>(&data[i]);
float2 q_data[2];
q_data[0] = __half22float2(data_h[0]);
q_data[1] = __half22float2(data_h[1]);
float2 q_data_int[2];
q_data_int[0].x = roundf((q_data[0].x - min) * q_scale_inv);
q_data_int[0].y = roundf((q_data[0].y - min) * q_scale_inv);
q_data_int[1].x = roundf((q_data[1].x - min) * q_scale_inv);
q_data_int[1].y = roundf((q_data[1].y - min) * q_scale_inv);
q_data_int[0].x = q_data_int[0].x * q_scale + min;
q_data_int[0].y = q_data_int[0].y * q_scale + min;
q_data_int[1].x = q_data_int[1].x * q_scale + min;
q_data_int[1].y = q_data_int[1].y * q_scale + min;
data_h[0] = __float22half2_rn(q_data_int[0]);
data_h[1] = __float22half2_rn(q_data_int[1]);
vals_cast[offset + group_index] = data[i];
}
}
}
#endif
}
__global__ void quantize_kernel_asym(float* vals, int group_size, int num_bits)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
float4* vals_cast = reinterpret_cast<float4*>(vals);
float4 data[MAX_REG];
int bid = blockIdx.x;
int group_index = bid * group_size + id;
int reg_count = 0;
float max = -10000.0;
float min = 10000.0;
while (id < group_size && reg_count < MAX_REG) {
float4 data_reg = vals_cast[group_index];
data[reg_count] = data_reg;
if (data_reg.x > max) max = data_reg.x;
if (data_reg.y > max) max = data_reg.y;
if (data_reg.w > max) max = data_reg.w;
if (data_reg.z > max) max = data_reg.z;
if (data_reg.x < min) min = data_reg.x;
if (data_reg.y < min) min = data_reg.y;
if (data_reg.w < min) min = data_reg.w;
if (data_reg.z < min) min = data_reg.z;
group_index += blockDim.x;
id += blockDim.x;
reg_count++;
}
id = threadIdx.x;
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(min, i);
if (min > temp) min = temp;
}
__shared__ float partialMax[WARP_SIZE];
__shared__ float partialMin[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
if (lane == 0) partialMin[gid] = min;
b.sync();
if (lane < warp_num) max = partialMax[lane];
if (lane < warp_num) min = partialMin[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(min, i);
if (min > temp) min = temp;
}
max = g.shfl(max, 0);
min = g.shfl(min, 0);
float q_scale = ((max - min) + 1e-5) / (float)(1 << num_bits);
float q_scale_inv = 1 / q_scale;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + id;
if (group_index < group_size) {
float4 q_data;
q_data = data[i];
float4 q_data_int;
q_data_int.x = roundf((q_data.x - min) * q_scale_inv);
q_data_int.y = roundf((q_data.y - min) * q_scale_inv);
q_data_int.w = roundf((q_data.w - min) * q_scale_inv);
q_data_int.z = roundf((q_data.z - min) * q_scale_inv);
q_data.x = q_data_int.x * q_scale + min;
q_data.y = q_data_int.y * q_scale + min;
q_data.w = q_data_int.w * q_scale + min;
q_data.z = q_data_int.z * q_scale + min;
vals_cast[group_index + bid * group_size] = q_data;
}
}
}
template <typename T>
void launch_quantize_kernel_asym(T* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream)
{
dim3 grid_dim(group_num);
dim3 block_dim(1024);
hipLaunchKernelGGL(( quantize_kernel_asym), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, (total_count / group_num) / 4, num_bits);
}
template void launch_quantize_kernel_asym(float* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
template void launch_quantize_kernel_asym(__half* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
__global__ void sr_quantize_kernel_asym(__half* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
float2* vals_cast = reinterpret_cast<float2*>(vals);
__half2 data_low[128];
__half2 data_high[128];
int bid = blockIdx.x;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
unsigned int tid = threadIdx.x;
int reg_count = 0;
int offset = bid * token_size;
int group_index = bid * token_size + tid;
int total_count = token_size * token_num;
if (group_index < total_count) {
float min = 10000.0;
float max = -10000.0;
while (tid < token_size) {
float2 data = vals_cast[offset + tid];
__half2* data_h = reinterpret_cast<__half2*>(&data);
data_low[reg_count] = data_h[0];
data_high[reg_count] = data_h[1];
float2 data_f[2];
data_f[0] = __half22float2(data_h[0]);
data_f[1] = __half22float2(data_h[1]);
if (((float)data_f[0].x) > max) max = (float)data_f[0].x;
if (((float)data_f[0].y) > max) max = (float)data_f[0].y;
if (((float)data_f[1].x) > max) max = (float)data_f[1].x;
if (((float)data_f[1].y) > max) max = (float)data_f[1].y;
if (((float)data_f[0].x) < min) min = (float)data_f[0].x;
if (((float)data_f[0].y) < min) min = (float)data_f[0].y;
if (((float)data_f[1].x) < min) min = (float)data_f[1].x;
if (((float)data_f[1].y) < min) min = (float)data_f[1].y;
tid += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(min, i);
if (min > temp) min = temp;
}
__shared__ float partialMax[WARP_SIZE];
__shared__ float partialMin[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
if (lane == 0) partialMin[gid] = min;
b.sync();
if (lane < warp_num) max = partialMax[lane];
if (lane < warp_num) min = partialMin[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(min, i);
if (min > temp) min = temp;
}
max = g.shfl(max, 0);
min = g.shfl(min, 0);
float q_scale_val = ((max - min) + 1e-5) / (float)(1 << num_bits);
float q_scale_val_inv = 1 / q_scale_val;
float high_q = (float)((1 << num_bits) - 1);
for (int i = 0; i < reg_count; i++) {
int token_index = i * blockDim.x + threadIdx.x;
if (token_index < token_size) {
float2 data_f[2];
data_f[0] = __half22float2(data_low[i]);
data_f[1] = __half22float2(data_high[i]);
float2 q_data_int[2];
q_data_int[0].x = (float)((unsigned int)((data_f[0].x - min) * q_scale_val_inv));
q_data_int[0].y = (float)((unsigned int)((data_f[0].y - min) * q_scale_val_inv));
q_data_int[1].x = (float)((unsigned int)((data_f[1].x - min) * q_scale_val_inv));
q_data_int[1].y = (float)((unsigned int)((data_f[1].y - min) * q_scale_val_inv));
// Stochastic rounding
float4 rand = hiprand_uniform4(&state);
float q_error[4];
q_error[0] =
abs(data_f[0].x - ((q_data_int[0].x * q_scale_val) + min)) * q_scale_val_inv;
q_error[1] =
abs(data_f[0].y - ((q_data_int[0].y * q_scale_val) + min)) * q_scale_val_inv;
q_error[2] =
abs(data_f[1].x - ((q_data_int[1].x * q_scale_val) + min)) * q_scale_val_inv;
q_error[3] =
abs(data_f[1].y - ((q_data_int[1].y * q_scale_val) + min)) * q_scale_val_inv;
q_data_int[0].x = (rand.x < q_error[0] && q_data_int[0].x < high_q)
? (q_data_int[0].x + 1)
: q_data_int[0].x;
q_data_int[0].y = (rand.y < q_error[1] && q_data_int[0].y < high_q)
? (q_data_int[0].y + 1)
: q_data_int[0].y;
q_data_int[1].x = (rand.w < q_error[2] && q_data_int[1].x < high_q)
? (q_data_int[1].x + 1)
: q_data_int[1].x;
q_data_int[1].y = (rand.z < q_error[3] && q_data_int[1].y < high_q)
? (q_data_int[1].y + 1)
: q_data_int[1].y;
data_f[0].x = q_data_int[0].x * q_scale_val + min;
data_f[0].y = q_data_int[0].y * q_scale_val + min;
data_f[1].x = q_data_int[1].x * q_scale_val + min;
data_f[1].y = q_data_int[1].y * q_scale_val + min;
float2 result;
__half2* result_h = reinterpret_cast<__half2*>(&result);
result_h[0] = __float22half2_rn(data_f[0]);
result_h[1] = __float22half2_rn(data_f[1]);
vals_cast[offset + token_index] = result;
}
}
}
#endif
}
__global__ void sr_quantize_kernel_asym(float* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
int idx = blockIdx.x * blockDim.x + id;
float4* vals_cast = reinterpret_cast<float4*>(vals);
float4 data[128];
int bid = blockIdx.x;
int tid = threadIdx.x;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
int group_index = bid * token_size + threadIdx.x;
int reg_count = 0;
int total_count = token_size * token_num;
if (group_index < total_count) {
float min = 10000.0;
float max = -10000.0;
while (tid < token_size) {
float4 data_reg = vals_cast[group_index];
data[reg_count] = data_reg;
if (data_reg.x > max) max = data_reg.x;
if (data_reg.y > max) max = data_reg.y;
if (data_reg.w > max) max = data_reg.w;
if (data_reg.z > max) max = data_reg.z;
if (data_reg.x < min) min = data_reg.x;
if (data_reg.y < min) min = data_reg.y;
if (data_reg.w < min) min = data_reg.w;
if (data_reg.z < min) min = data_reg.z;
group_index += blockDim.x;
tid += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(min, i);
if (min > temp) min = temp;
}
__shared__ float partialMax[WARP_SIZE];
__shared__ float partialMin[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
if (lane == 0) partialMin[gid] = min;
b.sync();
if (lane < warp_num) max = partialMax[lane];
if (lane < warp_num) min = partialMin[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(min, i);
if (min > temp) min = temp;
}
max = g.shfl(max, 0);
min = g.shfl(min, 0);
float q_scale_val = ((max - min) + 1e-5) / (float)(1 << num_bits);
float high_q = (float)((1 << num_bits) - 1);
int offset = (bid)*token_size;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + threadIdx.x;
if (group_index < token_size) {
float4 q_data = data[i];
float4 q_data_int;
q_data_int.x = (float)((int)((q_data.x - min) / q_scale_val));
q_data_int.y = (float)((int)((q_data.y - min) / q_scale_val));
q_data_int.w = (float)((int)((q_data.w - min) / q_scale_val));
q_data_int.z = (float)((int)((q_data.z - min) / q_scale_val));
// Stochastic rounding
float4 rand = hiprand_uniform4(&state);
float q_error[4];
q_error[0] = abs(q_data.x - ((q_data_int.x * q_scale_val) + min)) / q_scale_val;
q_error[1] = abs(q_data.y - ((q_data_int.y * q_scale_val) + min)) / q_scale_val;
q_error[2] = abs(q_data.w - ((q_data_int.w * q_scale_val) + min)) / q_scale_val;
q_error[3] = abs(q_data.z - ((q_data_int.z * q_scale_val) + min)) / q_scale_val;
q_data_int.x = (rand.x < q_error[0] && q_data_int.x < high_q) ? (q_data_int.x + 1)
: q_data_int.x;
q_data_int.y = (rand.y < q_error[1] && q_data_int.y < high_q) ? (q_data_int.y + 1)
: q_data_int.y;
q_data_int.w = (rand.w < q_error[2] && q_data_int.w < high_q) ? (q_data_int.w + 1)
: q_data_int.w;
q_data_int.z = (rand.z < q_error[3] && q_data_int.z < high_q) ? (q_data_int.z + 1)
: q_data_int.z;
q_data_int.x = q_data_int.x * q_scale_val + min;
q_data_int.y = q_data_int.y * q_scale_val + min;
q_data_int.w = q_data_int.w * q_scale_val + min;
q_data_int.z = q_data_int.z * q_scale_val + min;
vals_cast[group_index + offset] = q_data_int;
}
}
}
}
template <typename T>
void launch_sr_quantize_kernel_asym(T* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream)
{
dim3 block_dim(1024);
dim3 grid_dim(group_num);
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
hipLaunchKernelGGL(( sr_quantize_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, (total_count / group_num) / 4, group_num, num_bits, seed);
}
template void launch_sr_quantize_kernel_asym(float* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
template void launch_sr_quantize_kernel_asym(__half* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
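To summarize the group-wise arithmetic in the kernels above (a hedged paraphrase of the code, not new behaviour): the symmetric kernels scale by the group's absolute maximum and round, the asymmetric kernels use a min/max affine scale, and the sr_* variants truncate and then probabilistically round toward the next level (sign-aware in the symmetric case, always +1 in the asymmetric case) using hiprand_uniform4.

% Hedged summary (not part of the diff) of the quantization arithmetic above.
\text{symmetric: } s = \frac{2^{\text{bits}}}{2\max_j |x_j| + 10^{-5}}, \qquad
\hat{x}_j = \frac{\operatorname{round}(x_j s)}{s}
\text{asymmetric: } s = \frac{\max_j x_j - \min_j x_j + 10^{-5}}{2^{\text{bits}}}, \qquad
\hat{x}_j = \operatorname{round}\!\Big(\frac{x_j - \min_j x_j}{s}\Big)\, s + \min_j x_j
\text{stochastic (sr\_*, symmetric): } q_j = \operatorname{trunc}(x_j s), \quad
q_j \mathrel{+}= \operatorname{sign}(x_j) \text{ with probability } |x_j s - q_j|, \quad
\hat{x}_j = q_j / s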
csrc/transformer/cublas_wrappers.hip (deleted, 100644 → 0)
// !!! This is a file automatically generated by hipify!!!
#include "cublas_wrappers_hip.h"
#ifdef __HIP_PLATFORM_HCC__
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
rocblas_gemm_algo algo)
#else
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
rocblas_status status = rocblas_gemm_ex(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
rocblas_datatype_f32_r,
(transa == rocblas_operation_none) ? m : k,
(const void*)B,
rocblas_datatype_f32_r,
(transb == rocblas_operation_none) ? k : n,
(const void*)beta,
C,
rocblas_datatype_f32_r,
m,
C,
rocblas_datatype_f32_r,
m,
rocblas_datatype_f32_r,
algo,
0,
0);
#else
rocblas_status status = rocblas_gemmex(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
hipR32F,
(transa == rocblas_operation_none) ? m : k,
(const void*)B,
hipR32F,
(transb == rocblas_operation_none) ? k : n,
(const void*)beta,
C,
hipR32F,
m,
hipR32F,
algo);
#endif
#ifdef __HIP_PLATFORM_HCC__
if (status != rocblas_status_success) {
#else
if (status != rocblas_status_success) {
#endif
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
#ifdef __HIP_PLATFORM_HCC__
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
rocblas_gemm_algo algo)
#else
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
rocblas_status status = rocblas_gemm_ex(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
rocblas_datatype_f16_r,
(transa == rocblas_operation_none) ? m : k,
(const void*)B,
rocblas_datatype_f16_r,
(transb == rocblas_operation_none) ? k : n,
(const void*)beta,
(void*)C,
rocblas_datatype_f16_r,
m,
(void*)C,
rocblas_datatype_f16_r,
m,
rocblas_datatype_f32_r,
algo,
0,
0);
#else
rocblas_status status = rocblas_gemmex(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
hipR16F,
(transa == rocblas_operation_none) ? m : k,
(const void*)B,
hipR16F,
(transb == rocblas_operation_none) ? k : n,
(const void*)beta,
(void*)C,
hipR16F,
m,
hipR32F,
algo);
#endif
#ifdef __HIP_PLATFORM_HCC__
if (status != rocblas_status_success) {
#else
if (status != rocblas_status_success) {
#endif
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
#ifdef __HIP_PLATFORM_HCC__
int cublas_strided_batched_gemm(rocblas_handle handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
rocblas_operation op_A,
rocblas_operation op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
rocblas_gemm_algo algo)
#else
int cublas_strided_batched_gemm(rocblas_handle handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
rocblas_operation op_A,
rocblas_operation op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
rocblas_status status =
rocblas_gemm_strided_batched_ex(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
rocblas_datatype_f32_r,
(op_A == rocblas_operation_none) ? m : k,
stride_A,
B,
rocblas_datatype_f32_r,
(op_B == rocblas_operation_none) ? k : n,
stride_B,
beta,
C,
rocblas_datatype_f32_r,
m,
stride_C,
C,
rocblas_datatype_f32_r,
m,
stride_C,
batch,
rocblas_datatype_f32_r,
algo,
0,
0);
#else
rocblas_status status = cublasGemmStridedBatchedEx(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
hipR32F,
(op_A == rocblas_operation_none) ? m : k,
stride_A,
B,
hipR32F,
(op_B == rocblas_operation_none) ? k : n,
stride_B,
beta,
C,
hipR32F,
m,
stride_C,
batch,
hipR32F,
algo);
#endif
#ifdef __HIP_PLATFORM_HCC__
if (status != rocblas_status_success) {
#else
if (status != rocblas_status_success) {
#endif
fprintf(stderr,
"!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
batch,
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
#ifdef __HIP_PLATFORM_HCC__
int cublas_strided_batched_gemm(rocblas_handle handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
rocblas_operation op_A,
rocblas_operation op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
rocblas_gemm_algo algo)
#else
int cublas_strided_batched_gemm(rocblas_handle handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
rocblas_operation op_A,
rocblas_operation op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
rocblas_status status =
rocblas_gemm_strided_batched_ex(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
rocblas_datatype_f16_r,
(op_A == rocblas_operation_none) ? m : k,
stride_A,
B,
rocblas_datatype_f16_r,
(op_B == rocblas_operation_none) ? k : n,
stride_B,
beta,
C,
rocblas_datatype_f16_r,
m,
stride_C,
C,
rocblas_datatype_f16_r,
m,
stride_C,
batch,
rocblas_datatype_f32_r,
algo,
0,
0);
#else
rocblas_status status = cublasGemmStridedBatchedEx(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
hipR16F,
(op_A == rocblas_operation_none) ? m : k,
stride_A,
B,
hipR16F,
(op_B == rocblas_operation_none) ? k : n,
stride_B,
beta,
C,
hipR16F,
m,
stride_C,
batch,
hipR32F,
algo);
#endif
#ifdef __HIP_PLATFORM_HCC__
if (status != rocblas_status_success) {
#else
if (status != rocblas_status_success) {
#endif
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
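The wrappers above follow the rocBLAS/cuBLAS column-major convention: C is m x n with ldc = m, and the leading dimension of A (resp. B) is m (resp. k) when the operand is not transposed and k (resp. n) when it is. The sketch below is not part of the deleted file; it is a minimal plain-C++ host reference of that same convention, useful only for checking the wrappers' argument order, and the helper names are hypothetical.

// Reference check for the column-major GEMM convention used by cublas_gemm_ex:
// C(m x n) = alpha * op(A)(m x k) * op(B)(k x n) + beta * C, with ldc = m.
#include <cstdio>
#include <vector>

static void gemm_colmajor_ref(bool transa, bool transb, int m, int n, int k,
                              float alpha, const float* A, int lda,
                              const float* B, int ldb,
                              float beta, float* C, int ldc)
{
    for (int j = 0; j < n; ++j) {
        for (int i = 0; i < m; ++i) {
            float acc = 0.f;
            for (int p = 0; p < k; ++p) {
                // Column-major indexing: element (row, col) lives at row + col * ld.
                float a = transa ? A[p + i * lda] : A[i + p * lda];
                float b = transb ? B[j + p * ldb] : B[p + j * ldb];
                acc += a * b;
            }
            C[i + j * ldc] = alpha * acc + beta * C[i + j * ldc];
        }
    }
}

int main()
{
    const int m = 2, n = 3, k = 4;
    std::vector<float> A(m * k, 1.f), B(k * n, 2.f), C(m * n, 0.f);
    // Mirrors the non-transposed case of the wrapper: lda = m, ldb = k, ldc = m.
    gemm_colmajor_ref(false, false, m, n, k, 1.f, A.data(), m, B.data(), k, 0.f, C.data(), m);
    std::printf("C[0] = %f (expected %f)\n", C[0], 1.f * 2.f * k);
    return 0;
}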
csrc/transformer/dropout_kernels.hip
deleted
100644 → 0
View file @
7dd68788
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
const int unroll_factor = 4;
__global__ void dropout_kernel(const int N,
const float ratio,
float* out,
const float* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float4 rand = hiprand_uniform4(&state);
uint8_t m[unroll_factor];
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
int i = j * unroll_factor;
mask[i] = (uint8_t)m[0];
mask[i + 1] = (uint8_t)m[1];
mask[i + 2] = (uint8_t)m[2];
mask[i + 3] = (uint8_t)m[3];
out[i] = Xdata[i] * scale * m[0];
out[i + 1] = Xdata[i + 1] * scale * m[1];
out[i + 2] = Xdata[i + 2] * scale * m[2];
out[i + 3] = Xdata[i + 3] * scale * m[3];
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
out[i] = Xdata[i] * scale * m;
mask[i] = m;
}
}
}
__global__ void dropout_kernel(const int N,
const float ratio,
__half* out,
const __half* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
#ifdef __STOCHASTIC_MODE__
const __half2 h_scale = __float2half2_rn(scale);
const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
float2* out_cast = reinterpret_cast<float2*>(out);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
uint32_t m_32;
uint8_t* m = reinterpret_cast<uint8_t*>(&m_32);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
__half2 mask_h[2];
float2 mask_f[2];
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_f = x_cast[j];
__half2* x_h = reinterpret_cast<__half2*>(&x_f);
float4 rand = hiprand_uniform4(&state);
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
float* mask_f_data = &mask_f[0].x;
#pragma unroll
for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);
mask_h[0] = __float22half2_rn(mask_f[0]);
mask_h[1] = __float22half2_rn(mask_f[1]);
result_h[0] = x_h[0] * h_scale * mask_h[0];
result_h[1] = x_h[1] * h_scale * mask_h[1];
out_cast[j] = result_f;
mask_cast[j] = m_32;
}
#else
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
int i = j * unroll_factor;
const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
float2 vals_half_f[2];
vals_half_f[0] = __half22float2(vals_half[0]);
vals_half_f[1] = __half22float2(vals_half[1]);
uint8_t m[unroll_factor];
float4 rand = hiprand_uniform4(&state);
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);
mask[i] = m[0];
mask[i + 1] = m[1];
mask[i + 2] = m[2];
mask[i + 3] = m[3];
}
#endif
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
out[i] = __float2half((float)Xdata[i] * scale * m);
mask[i] = m;
}
}
}
__global__ void dropout_kernel_bwd(const int N,
const float ratio,
const float* Xdata,
float* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
int i = j * unroll_factor;
out[i] = mask[i] ? Xdata[i] * scale : 0.0;
out[i + 1] = mask[i + 1] ? Xdata[i + 1] * scale : 0.0;
out[i + 2] = mask[i + 2] ? Xdata[i + 2] * scale : 0.0;
out[i + 3] = mask[i + 3] ? Xdata[i + 3] * scale : 0.0;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) { out[i] = mask[i] ? Xdata[i] * scale : 0.0; }
}
}
__global__ void dropout_kernel_bwd(const int N,
const float ratio,
const __half* Xdata,
__half* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
#ifdef __STOCHASTIC_MODE__
const __half2 h_scale = __float2half2_rn(scale);
const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
float2* out_cast = reinterpret_cast<float2*>(out);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_f = x_cast[j];
__half2* x_h = reinterpret_cast<__half2*>(&x_f);
uint32_t m_32 = mask_cast[j];
uint8_t* m = (uint8_t*)&m_32;
__half2 mask_h[2];
float2 mask_f[2];
float* mask_f_data = &mask_f[0].x;
#pragma unroll
for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);
#pragma unroll
for (int i = 0; i < 2; i++) mask_h[i] = __float22half2_rn(mask_f[i]);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
result_h[0] = x_h[0] * h_scale * mask_h[0];
result_h[1] = x_h[1] * h_scale * mask_h[1];
out_cast[j] = result_f;
}
#else
const __half h_scale = __float2half(scale);
const __half h_zero = __float2half(0.0);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
int i = j * unroll_factor;
const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
uint8_t* m = mask + i;
float2 vals_half_f[2];
vals_half_f[0] = __half22float2(vals_half[0]);
vals_half_f[1] = __half22float2(vals_half[1]);
out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);
}
#endif
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) {
out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
}
}
}
template <typename T>
void launch_dropout(T* out,
const T* vals,
uint8_t* mask,
int total_count,
int dim,
float ratio,
hipStream_t stream,
bool bwd)
{
assert(unroll_factor == 4);
dim3 grid_dim = DS_GET_BLOCKS(total_count / unroll_factor);
dim3 block_dim = DS_CUDA_NUM_THREADS;
if (dim > 512) {
block_dim.x >>= 1;
grid_dim.x <<= 1;
}
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
if (bwd)
hipLaunchKernelGGL(( dropout_kernel_bwd), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, ratio, vals, out, mask, seed);
else
hipLaunchKernelGGL(( dropout_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, ratio, out, vals, mask, seed);
}
template void launch_dropout(float* out,
const float* vals,
uint8_t* mask,
int total_count,
int dim,
float ratio,
hipStream_t stream,
bool);
template void launch_dropout(__half* out,
const __half* vals,
uint8_t* mask,
int total_count,
int dim,
float ratio,
hipStream_t stream,
bool);
__global__ void dropout_grad_kernel(const int N, const float scale, float* Xdata, uint8_t* mask)
{
CUDA_1D_KERNEL_LOOP(i, N) { Xdata[i] *= scale * mask[i]; }
}
__global__ void dropout_grad_kernel(const int N, const float scale, __half* Xdata, uint8_t* mask)
{
const __half2 h_scale = __float2half2_rn(scale);
float2* x_cast = reinterpret_cast<float2*>(Xdata);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_data = x_cast[j];
uint32_t m_32 = mask_cast[j];
uint8_t* m = (uint8_t*)&m_32;
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
#ifdef __STOCHASTIC_MODE__
__half2* x_data_h = reinterpret_cast<__half2*>(&x_data);
__half2 mask_h[2];
float2 mask_f[2];
float* mask_f_data = &mask_f[0].x;
#pragma unroll
for (int i = 0; i < unroll_factor; i++) *(mask_f_data++) = (float)(m[i]);
mask_h[0] = __float22half2_rn(mask_f[0]);
mask_h[1] = __float22half2_rn(mask_f[1]);
result_h[0] = x_data_h[0] * h_scale * mask_h[0];
result_h[1] = x_data_h[1] * h_scale * mask_h[1];
#else
__half* x_data_h = reinterpret_cast<__half*>(&x_data);
float2 result[2];
result[0].x = (float)x_data_h[0] * scale * m[0];
result[0].y = (float)x_data_h[1] * scale * m[1];
result[1].x = (float)x_data_h[2] * scale * m[2];
result[1].y = (float)x_data_h[3] * scale * m[3];
result_h[0] = __float22half2_rn(result[0]);
result_h[1] = __float22half2_rn(result[1]);
#endif
x_cast[j] = result_f;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) {
Xdata[i] = __float2half((float)Xdata[i] * scale * mask[i]);
}
}
}
template <typename T>
void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, hipStream_t stream)
{
assert(unroll_factor == 4);
const float scale = 1. / (1. - ratio);
hipLaunchKernelGGL(( dropout_grad_kernel), dim3(DS_GET_BLOCKS(total_count / unroll_factor)),
dim3(DS_CUDA_NUM_THREADS),
0,
stream, total_count, scale, vals, mask);
}
template void launch_dropout_grad(float* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
template void launch_dropout_grad(__half* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
__global__ void dropout_grad_kernel(const int N,
const float scale,
const float* Xdata,
float* out,
uint8_t* mask)
{
CUDA_1D_KERNEL_LOOP(i, N) { out[i] = Xdata[i] * scale * mask[i]; }
}
__global__ void dropout_grad_kernel(const int N,
const float scale,
const __half* Xdata,
__half* out,
uint8_t* mask)
{
const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
float2* out_cast = reinterpret_cast<float2*>(out);
const uint32_t* mask_cast = reinterpret_cast<const uint32_t*>(mask);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_data = x_cast[j];
uint32_t m_32 = mask_cast[j];
uint8_t* m = (uint8_t*)&m_32;
__half* x_data_h = reinterpret_cast<__half*>(&x_data);
float2 result[2];
result[0].x = (float)x_data_h[0] * scale * m[0];
result[0].y = (float)x_data_h[1] * scale * m[1];
result[1].x = (float)x_data_h[2] * scale * m[2];
result[1].y = (float)x_data_h[3] * scale * m[3];
result_h[0] = __float22half2_rn(result[0]);
result_h[1] = __float22half2_rn(result[1]);
out_cast[j] = result_f;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) {
out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
}
}
}
template <typename T>
void launch_dropout_grad(T* vals_out,
const T* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream)
{
assert(unroll_factor == 4);
const float scale = 1. / (1. - ratio);
hipLaunchKernelGGL(( dropout_grad_kernel), dim3(DS_GET_BLOCKS(total_count / unroll_factor)),
dim3(DS_CUDA_NUM_THREADS),
0,
stream, total_count, scale, vals, vals_out, mask);
}
template void launch_dropout_grad(float*,
const float* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
template void launch_dropout_grad(__half*,
const __half* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const float* bias,
float* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
float4* Xdata_cast = reinterpret_cast<float4*>(Xdata);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = hiprand_uniform4(&state);
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
float4 x_data = Xdata_cast[j];
float4 b_data = bias_cast[j % (dim / unroll_factor)];
x_data.x += b_data.x;
x_data.y += b_data.y;
x_data.z += b_data.z;
x_data.w += b_data.w;
x_data.x = x_data.x * scale * m[0];
x_data.y = x_data.y * scale * m[1];
x_data.z = x_data.z * scale * m[2];
x_data.w = x_data.w * scale * m[3];
mask_32[j] = m_32;
Xdata_cast[j] = x_data;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = Xdata[i] + bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
Xdata[i] = x_data * scale * m;
mask[i] = m;
}
}
}
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const __half* bias,
__half* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
float2* Xdata_cast = reinterpret_cast<float2*>(Xdata);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = hiprand_uniform4(&state);
float2 data_f;
__half2* data_h = reinterpret_cast<__half2*>(&data_f);
float2 bias_f;
__half2* bias_h = reinterpret_cast<__half2*>(&bias_f);
data_f = Xdata_cast[j];
bias_f = bias_cast[j % (dim / unroll_factor)];
float2 data_h_0 = __half22float2(data_h[0]);
float2 data_h_1 = __half22float2(data_h[1]);
float2 bias_h_0 = __half22float2(bias_h[0]);
float2 bias_h_1 = __half22float2(bias_h[1]);
data_h_0.x += bias_h_0.x;
data_h_0.y += bias_h_0.y;
data_h_1.x += bias_h_1.x;
data_h_1.y += bias_h_1.y;
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
data_h_1.y = __float2half(data_h_1.y * scale * m[3]);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
result_h[0] = __float22half2_rn(data_h_0);
result_h[1] = __float22half2_rn(data_h_1);
Xdata_cast[j] = result_f;
mask_32[j] = m_32;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = (float)Xdata[i] + (float)bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
Xdata[i] = __float2half(x_data * scale * m);
mask[i] = m;
}
}
}
template <typename T>
void launch_dropout(T* out,
const T* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream)
{
assert(unroll_factor == 4);
int total_count = batch * dim / unroll_factor;
dim3 grid_dim = DS_GET_BLOCKS(total_count);
dim3 block_dim = DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
hipLaunchKernelGGL(( dropout_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, dim, ratio, bias, out, mask, seed);
}
template void launch_dropout(float*,
const float* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
template void launch_dropout(__half*,
const __half* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const float* input,
const float* residual,
const float* bias,
float* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
float4* out_cast = reinterpret_cast<float4*>(out);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
const float4* residual_cast = reinterpret_cast<const float4*>(residual);
const float4* input_cast = reinterpret_cast<const float4*>(input);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = hiprand_uniform4(&state);
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
float4 out_data;
float4 b_data = bias_cast[j % (dim / unroll_factor)];
float4 res_data = residual_cast[j];
float4 inp_data = input_cast[j];
out_data.x = (b_data.x + inp_data.x);
out_data.y = (b_data.y + inp_data.y);
out_data.z = (b_data.z + inp_data.z);
out_data.w = (b_data.w + inp_data.w);
out_data.x = out_data.x * scale * m[0];
out_data.y = out_data.y * scale * m[1];
out_data.z = out_data.z * scale * m[2];
out_data.w = out_data.w * scale * m[3];
out_data.x += res_data.x;
out_data.y += res_data.y;
out_data.z += res_data.z;
out_data.w += res_data.w;
mask_32[j] = m_32;
out_cast[j] = out_data;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = input[i] + bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
x_data = x_data * scale * m;
x_data += residual[i];
out[i] = x_data;
mask[i] = m;
}
}
}
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const __half* input,
const __half* residual,
const __half* bias,
__half* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
float2* out_cast = reinterpret_cast<float2*>(out);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
const float2* residual_cast = reinterpret_cast<const float2*>(residual);
const float2* input_cast = reinterpret_cast<const float2*>(input);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = hiprand_uniform4(&state);
float2 data_f;
__half2* data_h = reinterpret_cast<__half2*>(&data_f);
float2 bias_f;
__half2* bias_h = reinterpret_cast<__half2*>(&bias_f);
float2 residual_f;
__half2* residual_h = reinterpret_cast<__half2*>(&residual_f);
float2 input_f;
__half2* input_h = reinterpret_cast<__half2*>(&input_f);
bias_f = bias_cast[j % (dim / unroll_factor)];
residual_f = residual_cast[j];
input_f = input_cast[j];
float2 data_h_0 = __half22float2(data_h[0]);
float2 data_h_1 = __half22float2(data_h[1]);
float2 bias_h_0 = __half22float2(bias_h[0]);
float2 bias_h_1 = __half22float2(bias_h[1]);
float2 residual_h_0 = __half22float2(residual_h[0]);
float2 residual_h_1 = __half22float2(residual_h[1]);
float2 input_h_0 = __half22float2(input_h[0]);
float2 input_h_1 = __half22float2(input_h[1]);
data_h_0.x = (bias_h_0.x + input_h_0.x);
data_h_0.y = (bias_h_0.y + input_h_0.y);
data_h_1.x = (bias_h_1.x + input_h_1.x);
data_h_1.y = (bias_h_1.y + input_h_1.y);
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
data_h_1.y = __float2half(data_h_1.y * scale * m[3]);
data_h_0.x += residual_h_0.x;
data_h_0.y += residual_h_0.y;
data_h_1.x += residual_h_1.x;
data_h_1.y += residual_h_1.y;
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
result_h[0] = __float22half2_rn(data_h_0);
result_h[1] = __float22half2_rn(data_h_1);
out_cast[j] = result_f;
mask_32[j] = m_32;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = (float)input[i] + (float)bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
x_data = x_data * scale * m;
x_data += (float)residual[i];
out[i] = __float2half(x_data);
mask[i] = m;
}
}
}
template <typename T>
void launch_dropout(T* out,
const T* input,
const T* residual,
const T* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream)
{
assert(unroll_factor == 4);
int total_count = batch * dim / unroll_factor;
dim3 grid_dim = DS_GET_BLOCKS(total_count);
dim3 block_dim = DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
hipLaunchKernelGGL(( dropout_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, dim, ratio, input, residual, bias, out, mask, seed);
}
template void launch_dropout(float*,
const float*,
const float* residual,
const float* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
template void launch_dropout(__half*,
const __half*,
const __half* residual,
const __half* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
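Every kernel in this file implements inverted dropout: an element is kept with probability 1 - ratio, survivors are scaled by 1 / (1 - ratio) so the expected activation is unchanged, and the keep/drop decision is stored in a uint8_t mask that the backward kernels reuse. The snippet below is not from the deleted file; it is a small host-side sketch of the same arithmetic in plain C++, with hypothetical names, for reference only.

// Host-side sketch of the inverted-dropout math used by dropout_kernel:
// mask[i] = (rand > ratio), out[i] = x[i] * mask[i] / (1 - ratio).
#include <cstdint>
#include <cstdio>
#include <random>
#include <vector>

int main()
{
    const float ratio = 0.1f;                 // drop probability
    const float scale = 1.0f / (1.0f - ratio);
    std::vector<float> x(8, 1.0f), out(8);
    std::vector<uint8_t> mask(8);

    std::mt19937 gen(42);
    std::uniform_real_distribution<float> uniform(0.0f, 1.0f);

    for (size_t i = 0; i < x.size(); ++i) {
        mask[i] = static_cast<uint8_t>(uniform(gen) > ratio);  // 1 = keep, 0 = drop
        out[i] = x[i] * scale * mask[i];
    }
    // The backward pass reuses the stored mask: grad_in[i] = grad_out[i] * scale * mask[i].
    for (size_t i = 0; i < out.size(); ++i) std::printf("%f ", out[i]);
    std::printf("\n");
    return 0;
}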
csrc/transformer/ds_transformer_hip.cpp
deleted
100644 → 0
View file @
7dd68788
// !!! This is a file automatically generated by hipify!!!
#include <torch/extension.h>
#include <rocblas.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "Timer_hip.h"
#include "context_hip.h"
#include "cublas_wrappers_hip.h"
#include "custom_hip_layers.h"
#include "ds_transformer_hip.h"
static std::unordered_map<int, std::shared_ptr<void>> s_transformer_layers;

const int init_seq_length = 128;

// C++ interface

template <typename T>
unsigned get_workspace_size(unsigned maxBatchSize,
                            unsigned seq_len,
                            unsigned hidden_size,
                            unsigned intermediate_size,
                            unsigned heads,
                            bool training,
                            bool gelu_checkpoint)
{
    unsigned workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size);
    if (training) {
        workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size);
        workSpacesize += ((std::max)((size_t(maxBatchSize) * seq_len * intermediate_size),
                                     2 * (size_t(maxBatchSize) * heads * seq_len * seq_len)));
        if (gelu_checkpoint)
            workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * intermediate_size);
    }
    return workSpacesize;  // * sizeof(T);
}
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
template <typename T>
BertTransformerLayer<T>::BertTransformerLayer(unsigned layer_id,
                                              unsigned batch_size,
                                              unsigned hidden_size,
                                              unsigned num_heads,
                                              unsigned intermediate_size,
                                              unsigned seq_length,
                                              float attn_prob_dropout_ratio,
                                              float hidden_output_dropout_ratio,
                                              float layer_norm_eps,
                                              bool pre_or_postLayerNorm,
                                              const std::vector<std::array<int, 3>>& gemm_algos,
                                              bool attn_dropout_checkpoint,
                                              bool normalize_invertible,
                                              bool gelu_checkpoint,
                                              bool stochastic_mode)
    : _layer_id(layer_id),
      _batch_size(batch_size),
      _hidden_size(hidden_size),
      _heads(num_heads),
      _intermediate_size(intermediate_size),
      _seq_length(seq_length),
      _training(true),
      _pre_or_postLayerNorm(pre_or_postLayerNorm),
      _attn_dropout_checkpoint(attn_dropout_checkpoint),
      _normalize_invertible(normalize_invertible),
      _gelu_checkpoint(gelu_checkpoint),
      _stochastic_mode(stochastic_mode),
      _stream(Context::Instance().GetCurrentStream()),
      _cublasHandle(Context::Instance().GetCublasHandle()),
      _qkv_linear(typename FeedForward<T>::Config(batch_size * seq_length,
                                                  3 * hidden_size,
                                                  hidden_size,
                                                  gemm_algos[0])),
      _attn_out_linear(typename FeedForward<T>::Config(batch_size * seq_length,
                                                       hidden_size,
                                                       hidden_size,
                                                       gemm_algos[0])),
      _attn_layer_norm(typename Normalize_Layer<T>::Config(batch_size,
                                                           seq_length,
                                                           hidden_size,
                                                           layer_norm_eps,
                                                           true,
                                                           !normalize_invertible)),
      _layer_norm(typename Normalize_Layer<T>::Config(batch_size,
                                                      seq_length,
                                                      hidden_size,
                                                      layer_norm_eps,
                                                      true,
                                                      !normalize_invertible)),
      _ff1(typename FeedForward<T>::Config(batch_size * seq_length,
                                           _intermediate_size,
                                           hidden_size,
                                           gemm_algos[1])),
      _ff2(typename FeedForward<T>::Config(batch_size * seq_length,
                                           hidden_size,
                                           _intermediate_size,
                                           gemm_algos[2])),
      _softmax(typename Softmax<T>::Config(batch_size, num_heads, seq_length)),
      _gelu(typename Gelu<T>::Config(_intermediate_size)),
      _attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio, _seq_length)),
      _attn_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
      _layer_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
      _attn_scores(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
                                                        _seq_length,
                                                        _seq_length,
                                                        _hidden_size / _heads,
                                                        //(T(1.0) / T(sqrt(_hidden_size / _heads))),
                                                        // aiss debug 0506
                                                        (T(1.0 / (sqrt(_hidden_size / _heads)))),
                                                        T(0.0),
                                                        rocblas_operation_transpose,
                                                        rocblas_operation_none,
                                                        gemm_algos[3])),
      _attn_context(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
                                                         _hidden_size / _heads,
                                                         _seq_length,
                                                         _seq_length,
                                                         T(1.0),
                                                         T(0.0),
                                                         rocblas_operation_none,
                                                         rocblas_operation_none,
                                                         gemm_algos[4]))
{
    assert(_hidden_size % _heads == 0);

    Initialize();
}
template <typename T>
BertTransformerLayer<T>::~BertTransformerLayer()
{
}

template <typename T>
void BertTransformerLayer<T>::Initialize()
{
#ifndef __HIP_PLATFORM_HCC__
    if (std::is_same<T, __half>::value)
        rocblas_set_math_mode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
#endif
}
template <typename T>
void BertTransformerLayer<T>::Forward(unsigned bsz,
                                      const T* input_ptr,
                                      const T* input_mask_ptr,
                                      const T* attn_qkvw_ptr,
                                      const T* attn_qkvb_ptr,
                                      const T* attn_ow_ptr,
                                      const T* attn_ob_ptr,
                                      const T* attn_nw_ptr,
                                      const T* attn_nb_ptr,
                                      const T* inter_w_ptr,
                                      const T* inter_b_ptr,
                                      const T* output_w_ptr,
                                      const T* output_b_ptr,
                                      const T* norm_w_ptr,
                                      const T* norm_b_ptr,
                                      T* out_ptr,
                                      T* inp_norm_ptr,
                                      T* q_tf_ptr,
                                      T* k_tf_ptr,
                                      T* v_tf_ptr,
                                      T* soft_out_ptr,
                                      T* ctx_bufB_ptr,
                                      T* attn_o_inp_ptr,
                                      T* add_res_ptr,
                                      T* ff1_inp_ptr,
                                      T* gelu_inp_ptr,
                                      T* ff2_inp_ptr)
{
    rocblas_set_stream(_cublasHandle, _stream);

    if (!_stochastic_mode) hipStreamSynchronize(_stream);

    T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
    size_t small_buf_size = bsz * _seq_length * _hidden_size;
    T* buf_0 = workspace;
    T* buf_1 = buf_0 + small_buf_size;
    T* buf_2 = buf_1;

    if (_normalize_invertible) {
        add_res_ptr = buf_1 + 3 * small_buf_size;
        buf_2 = add_res_ptr;
    }
    if (_gelu_checkpoint) buf_2 += small_buf_size;
    if (_attn_dropout_checkpoint)
        ctx_bufB_ptr = (_gelu_checkpoint
                            ? (buf_2 + (_intermediate_size / _hidden_size) * small_buf_size)
                            : (buf_1 + 4 * small_buf_size));

    int bsz_seq = bsz * _seq_length;

    if (_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.ForwardCheckpoint(
                bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        else
            _layer_norm.Forward(
                bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
    }

    if (_pre_or_postLayerNorm)
        _qkv_linear.Forward(bsz_seq, inp_norm_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);
    else
        _qkv_linear.Forward(bsz_seq, input_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);

    launch_bias_add_transform_0213<T>(
        q_tf_ptr, buf_0, attn_qkvb_ptr, bsz, _seq_length, _hidden_size, _heads, _stream, 3);

    int bsz_heads = bsz * _heads;

    // attention scores
    _attn_scores.Forward(bsz_heads, soft_out_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle);

    // Softmax + Mask
    _softmax.Forward(bsz, soft_out_ptr, input_mask_ptr, _stream);

    // attn prob dropout.
    _attn_prob_dropout.Forward(bsz_heads * _seq_length, ctx_bufB_ptr, soft_out_ptr, _stream);

    // attention context
    _attn_context.Forward(bsz_heads, buf_1, v_tf_ptr, ctx_bufB_ptr, _cublasHandle);

    launch_transform4d_0213<T>(
        attn_o_inp_ptr, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 1);

    if (_pre_or_postLayerNorm)
        _attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, buf_1, _cublasHandle);
    else
        _attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, ff1_inp_ptr, _cublasHandle);

    // attn output dropout.
    if (_pre_or_postLayerNorm)
        _attn_output_dropout.ForwardWithBias(
            bsz_seq, add_res_ptr, buf_1, input_ptr, attn_ob_ptr, _stream);
    else
        _attn_output_dropout.ForwardWithBias(
            bsz_seq, add_res_ptr, ff1_inp_ptr, input_ptr, attn_ob_ptr, _stream);

    if (_pre_or_postLayerNorm) {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.ForwardCheckpoint(
                bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
        else
            _attn_layer_norm.Forward(
                bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
    } else {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.ForwardCheckpoint(
                bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
        else
            _attn_layer_norm.Forward(
                bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
    }

    _ff1.Forward(bsz_seq,
                 ff1_inp_ptr,
                 inter_w_ptr,
                 (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
                 _cublasHandle);

    _gelu.ForwardWithBiasAdd(bsz_seq,
                             (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
                             inter_b_ptr,
                             (_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
                             _stream);

    _ff2.Forward(
        bsz_seq, (_gelu_checkpoint ? buf_2 : ff2_inp_ptr), output_w_ptr, out_ptr, _cublasHandle);

    // layer output dropout.
    if (_pre_or_postLayerNorm)
        _layer_output_dropout.ForwardWithBias(
            bsz_seq, out_ptr, out_ptr, add_res_ptr, output_b_ptr, _stream);
    else
        _layer_output_dropout.ForwardWithBias(
            bsz_seq, inp_norm_ptr, out_ptr, ff1_inp_ptr, output_b_ptr, _stream);

    if (!_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.ForwardCheckpoint(
                bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        else
            _layer_norm.Forward(
                bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
    }
}
template <typename T>
void BertTransformerLayer<T>::Backward(unsigned bsz,
                                       const T* grad_output_ptr,
                                       const T* input_ptr,
                                       const T* output_ptr,
                                       const T* inp_norm_ptr,
                                       const T* q_tf_ptr,
                                       const T* k_tf_ptr,
                                       const T* v_tf_ptr,
                                       const T* soft_out_ptr,
                                       const T* ctx_bufB_ptr,
                                       const T* attn_o_inp_ptr,
                                       const T* add_res_ptr,
                                       const T* ff1_inp_ptr,
                                       const T* gelu_inp_ptr,
                                       const T* ff2_inp_ptr,
                                       const T* input_mask_ptr,
                                       const T* attn_qkvw_ptr,
                                       const T* attn_ow_ptr,
                                       const T* attn_nw_ptr,
                                       const T* attn_nb_ptr,
                                       const T* inter_w_ptr,
                                       const T* inter_b_ptr,
                                       const T* output_w_ptr,
                                       const T* norm_w_ptr,
                                       const T* norm_b_ptr,
                                       T* grad_input_ptr,
                                       T* grad_attn_qkvw_ptr,
                                       T* grad_attn_qkvb_ptr,
                                       T* grad_attn_ow_ptr,
                                       T* grad_attn_ob_ptr,
                                       T* grad_attn_nw_ptr,
                                       T* grad_attn_nb_ptr,
                                       T* grad_inter_w_ptr,
                                       T* grad_inter_b_ptr,
                                       T* grad_output_w_ptr,
                                       T* grad_output_b_ptr,
                                       T* grad_norm_w_ptr,
                                       T* grad_norm_b_ptr)
{
    rocblas_set_stream(_cublasHandle, _stream);

    if (!_stochastic_mode) hipStreamSynchronize(_stream);

    T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
    size_t small_buf_size = bsz * _seq_length * _hidden_size;
    T* buf_0 = workspace;
    T* buf_1 = buf_0 + small_buf_size;
    T* buf_2 = buf_1 + small_buf_size;
    T* buf_3 = buf_2 + small_buf_size;

    T* ff2_buf = (_gelu_checkpoint ? buf_3 + (bsz * _seq_length * _intermediate_size)
                                   : buf_3 + small_buf_size);
    T* ctx_bufB_ptr_recomp = ff2_buf + (_seq_length * _seq_length * bsz * _heads);

    hipStream_t streams[2] = {_stream, _stream};

    int bsz_seq = bsz * _seq_length;
    int bsz_heads = bsz * _heads;

    if (!_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.Backward(bsz_seq,
                                 grad_output_ptr,
                                 norm_w_ptr,
                                 grad_norm_w_ptr,
                                 grad_norm_b_ptr,
                                 streams,
                                 buf_1,
                                 inp_norm_ptr);
        else
            _layer_norm.Backward(bsz_seq,
                                 grad_output_ptr,
                                 norm_w_ptr,
                                 norm_b_ptr,
                                 grad_norm_w_ptr,
                                 grad_norm_b_ptr,
                                 streams,
                                 buf_1,
                                 output_ptr);
    }

    if (_pre_or_postLayerNorm)
        _layer_output_dropout.Backward(bsz_seq, buf_0, grad_output_ptr, _stream);
    else
        _layer_output_dropout.Backward(bsz_seq, buf_0, buf_1, _stream);

    const T* layer_dropout_buf = _layer_output_dropout.HasDropout()
                                     ? buf_0
                                     : (_pre_or_postLayerNorm ? grad_output_ptr : buf_1);

    if (_gelu_checkpoint)
        _gelu.ForwardWithBiasAdd(bsz_seq, ff2_inp_ptr, inter_b_ptr, buf_2, _stream);
    _ff2.Backward(bsz_seq,
                  layer_dropout_buf,
                  (_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
                  output_w_ptr,
                  grad_output_w_ptr,
                  grad_output_b_ptr,
                  _cublasHandle,
                  _stream,
                  ff2_buf);

    _gelu.Backward(
        bsz_seq, ff2_buf, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, _stream);

    _ff1.Backward(bsz_seq,
                  ff2_buf,
                  ff1_inp_ptr,
                  inter_w_ptr,
                  grad_inter_w_ptr,
                  grad_inter_b_ptr,
                  _cublasHandle,
                  _stream,
                  buf_3);

    if (!_pre_or_postLayerNorm)
        launch_fused_add2<T>(buf_2, buf_3, buf_1, bsz, _seq_length, _hidden_size, _stream);

    if (_pre_or_postLayerNorm) {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.BackwardFusedAdd(bsz_seq,
                                              buf_3,
                                              grad_output_ptr,
                                              attn_nw_ptr,
                                              grad_attn_nw_ptr,
                                              grad_attn_nb_ptr,
                                              streams,
                                              buf_0,
                                              add_res_ptr);
        else
            _attn_layer_norm.BackwardFusedAdd(bsz_seq,
                                              buf_3,
                                              grad_output_ptr,
                                              attn_nw_ptr,
                                              attn_nb_ptr,
                                              grad_attn_nw_ptr,
                                              grad_attn_nb_ptr,
                                              streams,
                                              buf_0,
                                              ff1_inp_ptr);
    } else {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.Backward(bsz_seq,
                                      buf_2,
                                      attn_nw_ptr,
                                      grad_attn_nw_ptr,
                                      grad_attn_nb_ptr,
                                      streams,
                                      buf_0,
                                      add_res_ptr);
        else
            _attn_layer_norm.Backward(bsz_seq,
                                      buf_2,
                                      attn_nw_ptr,
                                      attn_nb_ptr,
                                      grad_attn_nw_ptr,
                                      grad_attn_nb_ptr,
                                      streams,
                                      buf_0,
                                      ff1_inp_ptr);
    }

    _attn_output_dropout.Backward(bsz_seq, buf_2, buf_0, _stream);

    T* attn_output_dropout_buf = _attn_output_dropout.HasDropout() ? buf_2 : buf_0;

    _attn_out_linear.Backward(bsz_seq,
                              attn_output_dropout_buf,
                              attn_o_inp_ptr,
                              attn_ow_ptr,
                              grad_attn_ow_ptr,
                              grad_attn_ob_ptr,
                              _cublasHandle,
                              _stream,
                              buf_1);

    launch_transform_0213<T>(buf_2, buf_1, bsz, _seq_length, _hidden_size, _heads, _stream);

    if (_attn_prob_dropout.HasDropout()) {
        if (_attn_dropout_checkpoint)
            _attn_prob_dropout.Forward(
                bsz_heads * _seq_length, ctx_bufB_ptr_recomp, soft_out_ptr, _stream, true);

        _attn_context.Backward(bsz_heads,
                               buf_2,
                               v_tf_ptr,
                               (_attn_dropout_checkpoint ? ctx_bufB_ptr_recomp : ctx_bufB_ptr),
                               _cublasHandle,
                               buf_3,
                               ff2_buf);
    } else
        _attn_context.Backward(
            bsz_heads, buf_2, v_tf_ptr, soft_out_ptr, _cublasHandle, buf_3, ff2_buf);

    _attn_prob_dropout.Backward(bsz_heads * _seq_length, ff2_buf, _stream);

    _softmax.Backward(bsz, ff2_buf, soft_out_ptr, _stream);

    _attn_scores.Backward(bsz_heads, ff2_buf, k_tf_ptr, q_tf_ptr, _cublasHandle, buf_2, buf_1);

    launch_transform4d_0213(ff2_buf, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 3);

    if (_pre_or_postLayerNorm)
        _qkv_linear.Backward(bsz_seq,
                             ff2_buf,
                             inp_norm_ptr,
                             attn_qkvw_ptr,
                             grad_attn_qkvw_ptr,
                             grad_attn_qkvb_ptr,
                             _cublasHandle,
                             _stream,
                             buf_2);
    else
        _qkv_linear.Backward(bsz_seq,
                             ff2_buf,
                             input_ptr,
                             attn_qkvw_ptr,
                             grad_attn_qkvw_ptr,
                             grad_attn_qkvb_ptr,
                             _cublasHandle,
                             _stream,
                             buf_2);

    if (_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.BackwardFusedAdd(bsz_seq,
                                         buf_2,
                                         buf_0,
                                         norm_w_ptr,
                                         grad_norm_w_ptr,
                                         grad_norm_b_ptr,
                                         streams,
                                         grad_input_ptr,
                                         input_ptr);
        else
            _layer_norm.BackwardFusedAdd(bsz_seq,
                                         buf_2,
                                         buf_0,
                                         norm_w_ptr,
                                         norm_b_ptr,
                                         grad_norm_w_ptr,
                                         grad_norm_b_ptr,
                                         streams,
                                         grad_input_ptr,
                                         inp_norm_ptr);
    } else
        launch_fused_add2<T>(
            grad_input_ptr, buf_2, buf_0, bsz, _seq_length, _hidden_size, _stream);
}
template <typename T>
void BertTransformerLayer<T>::SetTrainingMode(bool training)
{
    // Dropout will be skipped when not in training mode.
    _attn_prob_dropout.SetTrainingMode(training);
    _attn_output_dropout.SetTrainingMode(training);
    _layer_output_dropout.SetTrainingMode(training);
}

template <typename T>
void BertTransformerLayer<T>::SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
                                                     uint8_t* attn_output_dropout_mask_ptr,
                                                     uint8_t* layer_output_dropout_mask_ptr,
                                                     T* attn_layer_norm_var,
                                                     T* attn_layer_norm_mean,
                                                     T* layer_norm_var,
                                                     T* layer_norm_mean)
{
    _attn_prob_dropout.SetMask(attn_prob_dropout_mask_ptr);
    _attn_output_dropout.SetMask(attn_output_dropout_mask_ptr);
    _layer_output_dropout.SetMask(layer_output_dropout_mask_ptr);

    _attn_layer_norm.SetVar(attn_layer_norm_var);
    _attn_layer_norm.SetMean(attn_layer_norm_mean);
    _layer_norm.SetVar(layer_norm_var);
    _layer_norm.SetMean(layer_norm_mean);
}

template <typename T>
void BertTransformerLayer<T>::SetSeqLength(unsigned seq_len)
{
    _seq_length = seq_len;

    _softmax.SetSeqLength(_seq_length);
    _attn_prob_dropout.SetDimension(_seq_length);
    _attn_scores.SetConfig(_seq_length, _seq_length, _hidden_size / _heads);
    _attn_context.SetConfig(_hidden_size / _heads, _seq_length, _seq_length);
}
template <typename T>
int create_transformer_layer(unsigned layer_id,
                             unsigned batch_size,
                             unsigned hidden_dim,
                             unsigned num_heads,
                             unsigned intermediate_size,
                             float attn_dropout_ratio,
                             float hidden_dropout_ratio,
                             float layer_norm_eps,
                             int seed,
                             bool pre_or_postLayerNorm,
                             bool test_gemm,
                             bool attn_dropout_checkpoint,
                             bool normalize_invertible,
                             bool gelu_checkpoint,
                             bool stochastic_mode)
{
    Context::Instance().SetSeed(seed);
    Context::Instance().TestGemmFP16(
        test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads);

    auto layer = std::make_shared<BertTransformerLayer<T>>(layer_id,
                                                           batch_size,
                                                           hidden_dim,
                                                           num_heads,
                                                           intermediate_size,
                                                           init_seq_length,
                                                           attn_dropout_ratio,
                                                           hidden_dropout_ratio,
                                                           layer_norm_eps,
                                                           pre_or_postLayerNorm,
                                                           Context::Instance().GetGemmAlgos(),
                                                           attn_dropout_checkpoint,
                                                           normalize_invertible,
                                                           gelu_checkpoint,
                                                           stochastic_mode);

    s_transformer_layers[layer_id] = layer;

    std::string dtype = (std::is_same<T, __half>::value) ? "half" : "float";
    std::cout << "layer #" << layer_id << " is created with data type [" << dtype << "]."
              << std::endl;

    return 0;
}
template <typename T>
std::vector<torch::Tensor> ds_transformer_forward(unsigned layer_id,
                                                  const torch::Tensor& input,
                                                  const torch::Tensor& input_mask,
                                                  const torch::Tensor& attn_qkvw,
                                                  const torch::Tensor& attn_qkvb,
                                                  const torch::Tensor& attn_ow,
                                                  const torch::Tensor& attn_ob,
                                                  const torch::Tensor& attn_nw,
                                                  const torch::Tensor& attn_nb,
                                                  const torch::Tensor& inter_w,
                                                  const torch::Tensor& inter_b,
                                                  const torch::Tensor& output_w,
                                                  const torch::Tensor& output_b,
                                                  const torch::Tensor& norm_w,
                                                  const torch::Tensor& norm_b,
                                                  bool training_mode,
                                                  bool prelayernorm,
                                                  bool attn_dropout_checkpoint,
                                                  bool normalize_invertible,
                                                  bool gelu_checkpoint)
{
    CHECK_INPUT(input);
    CHECK_INPUT(input_mask);
    CHECK_INPUT(attn_qkvw);
    CHECK_INPUT(attn_qkvb);
    CHECK_INPUT(attn_ow);
    CHECK_INPUT(attn_ob);
    CHECK_INPUT(attn_nw);
    CHECK_INPUT(attn_nb);
    CHECK_INPUT(inter_w);
    CHECK_INPUT(inter_b);
    CHECK_INPUT(output_w);
    CHECK_INPUT(output_b);
    CHECK_INPUT(norm_w);
    CHECK_INPUT(norm_b);

    unsigned bsz = input.size(0);

    const T* input_ptr = (const T*)input.data_ptr();
    const T* input_mask_ptr = (const T*)input_mask.data_ptr();
    const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
    const T* attn_qkvb_ptr = (const T*)attn_qkvb.data_ptr();
    const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
    const T* attn_ob_ptr = (const T*)attn_ob.data_ptr();
    const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
    const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
    const T* inter_w_ptr = (const T*)inter_w.data_ptr();
    const T* inter_b_ptr = (const T*)inter_b.data_ptr();
    const T* output_w_ptr = (const T*)output_w.data_ptr();
    const T* output_b_ptr = (const T*)output_b.data_ptr();
    const T* norm_w_ptr = (const T*)norm_w.data_ptr();
    const T* norm_b_ptr = (const T*)norm_b.data_ptr();

    auto output = torch::empty_like(input);
    T* out_ptr = (T*)output.data_ptr();

    auto options = torch::TensorOptions()
                       .dtype(input.options().dtype())
                       .layout(torch::kStrided)
                       .device(torch::kCUDA)
                       .requires_grad(true);

    auto uint8_options = torch::TensorOptions()
                             .dtype(torch::kInt8)
                             .layout(torch::kStrided)
                             .device(torch::kCUDA)
                             .requires_grad(false);

    std::shared_ptr<BertTransformerLayer<T>> layer =
        std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

    unsigned seq_len = layer->GetSeqLength();
    if (input.size(1) != seq_len) {
        seq_len = input.size(1);
        layer->SetSeqLength(seq_len);
    }

    auto workspace = torch::empty({get_workspace_size<T>(bsz,
                                                         seq_len,
                                                         layer->GetHiddenSize(),
                                                         layer->GetIntermediateSize(),
                                                         layer->GetNumHeads(),
                                                         layer->IsTrainingMode(),
                                                         layer->GeluCheckpoint())},
                                  options);
    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());

    auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
    auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
    auto attn_o_inp = torch::empty_like(input);
    auto qkv_tf = torch::empty({(bsz * seq_len), output_w.size(0) * 3}, options);

    auto attn_prob_dropout_mask =
        torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, uint8_options);
    auto attn_output_dropout_mask =
        torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);
    auto layer_output_dropout_mask =
        torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);

    auto attn_layer_norm_var = torch::empty({(bsz * seq_len)}, options);
    auto attn_layer_norm_mean = torch::empty({(bsz * seq_len)}, options);
    auto layer_norm_var = torch::empty({(bsz * seq_len)}, options);
    auto layer_norm_mean = torch::empty({(bsz * seq_len)}, options);

    T* inp_norm_ptr = (T*)inp_norm.data_ptr();
    T* add_res_ptr = (T*)add_res.data_ptr();
    T* q_tf_ptr = (T*)qkv_tf.data_ptr();
    T* k_tf_ptr = q_tf_ptr + (bsz * seq_len * output_w.size(0));  //(T*)k_tf.data_ptr();
    T* v_tf_ptr = k_tf_ptr + (bsz * seq_len * output_w.size(0));  //(T*)v_tf.data_ptr();
    T* attn_o_inp_ptr = (T*)attn_o_inp.data_ptr();

    torch::Tensor ff2_inp = torch::empty({(bsz * seq_len), output_w.size(1)}, options);
    torch::Tensor gelu_inp =
        (gelu_checkpoint ? ff2_inp : torch::empty({(bsz * seq_len), output_w.size(1)}, options));
    auto ff1_inp = torch::empty_like(input);
    T* ff2_inp_ptr = (T*)ff2_inp.data_ptr();
    T* gelu_inp_ptr = (T*)gelu_inp.data_ptr();
    T* ff1_inp_ptr = (T*)ff1_inp.data_ptr();

    torch::Tensor soft_out =
        torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options);
    torch::Tensor ctx_bufB =
        (attn_dropout_checkpoint
             ? soft_out
             : torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options));
    T* soft_out_ptr = (T*)soft_out.data_ptr();
    T* ctx_bufB_ptr = (T*)ctx_bufB.data_ptr();

    layer->SetTrainingMode(training_mode);
    layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
                                  (uint8_t*)attn_output_dropout_mask.data_ptr(),
                                  (uint8_t*)layer_output_dropout_mask.data_ptr(),
                                  (T*)attn_layer_norm_var.data_ptr(),
                                  (T*)attn_layer_norm_mean.data_ptr(),
                                  (T*)layer_norm_var.data_ptr(),
                                  (T*)layer_norm_mean.data_ptr());

    layer->Forward(bsz,
                   input_ptr,
                   input_mask_ptr,
                   attn_qkvw_ptr,
                   attn_qkvb_ptr,
                   attn_ow_ptr,
                   attn_ob_ptr,
                   attn_nw_ptr,
                   attn_nb_ptr,
                   inter_w_ptr,
                   inter_b_ptr,
                   output_w_ptr,
                   output_b_ptr,
                   norm_w_ptr,
                   norm_b_ptr,
                   out_ptr,
                   inp_norm_ptr,
                   q_tf_ptr,
                   k_tf_ptr,
                   v_tf_ptr,
                   soft_out_ptr,
                   ctx_bufB_ptr,
                   attn_o_inp_ptr,
                   add_res_ptr,
                   ff1_inp_ptr,
                   gelu_inp_ptr,
                   ff2_inp_ptr);

    return {output,
            inp_norm,
            qkv_tf,
            soft_out,
            ctx_bufB,
            attn_o_inp,
            add_res,
            ff1_inp,
            gelu_inp,
            ff2_inp,
            attn_prob_dropout_mask,
            attn_output_dropout_mask,
            layer_output_dropout_mask,
            attn_layer_norm_var,
            attn_layer_norm_mean,
            layer_norm_var,
            layer_norm_mean};
}
template <typename T>
std::vector<torch::Tensor> ds_transformer_backward(unsigned layer_id,
                                                   const torch::Tensor& grad_output,
                                                   const torch::Tensor& output,
                                                   const torch::Tensor& inp_norm,
                                                   const torch::Tensor& qkv_tf,
                                                   const torch::Tensor& soft_out,
                                                   const torch::Tensor& ctx_bufB,
                                                   const torch::Tensor& attn_o_inp,
                                                   const torch::Tensor& add_res,
                                                   const torch::Tensor& ff1_inp,
                                                   const torch::Tensor& gelu_inp,
                                                   const torch::Tensor& ff2_inp,
                                                   const torch::Tensor& attn_prob_dropout_mask,
                                                   const torch::Tensor& attn_output_dropout_mask,
                                                   const torch::Tensor& layer_output_dropout_mask,
                                                   const torch::Tensor& attn_layer_norm_var,
                                                   const torch::Tensor& attn_layer_norm_mean,
                                                   const torch::Tensor& layer_norm_var,
                                                   const torch::Tensor& layer_norm_mean,
                                                   const torch::Tensor& input,
                                                   const torch::Tensor& input_mask,
                                                   const torch::Tensor& attn_qkvw,
                                                   const torch::Tensor& attn_qkvb,
                                                   const torch::Tensor& attn_ow,
                                                   const torch::Tensor& attn_ob,
                                                   const torch::Tensor& attn_nw,
                                                   const torch::Tensor& attn_nb,
                                                   const torch::Tensor& inter_w,
                                                   const torch::Tensor& inter_b,
                                                   const torch::Tensor& output_w,
                                                   const torch::Tensor& output_b,
                                                   const torch::Tensor& norm_w,
                                                   const torch::Tensor& norm_b)
{
    auto g_output = grad_output.contiguous();
    CHECK_INPUT(g_output);
    CHECK_INPUT(output);
    CHECK_INPUT(inp_norm);
    CHECK_INPUT(qkv_tf);
    CHECK_INPUT(add_res);
    CHECK_INPUT(soft_out);
    CHECK_INPUT(ctx_bufB);
    CHECK_INPUT(attn_o_inp);
    CHECK_INPUT(ff1_inp);
    CHECK_INPUT(gelu_inp);
    CHECK_INPUT(ff2_inp);
    CHECK_INPUT(input);
    CHECK_INPUT(input_mask);
    CHECK_INPUT(attn_qkvw);
    CHECK_INPUT(attn_qkvb);
    CHECK_INPUT(attn_ow);
    CHECK_INPUT(attn_ob);
    CHECK_INPUT(attn_nw);
    CHECK_INPUT(attn_nb);
    CHECK_INPUT(inter_w);
    CHECK_INPUT(inter_b);
    CHECK_INPUT(output_w);
    CHECK_INPUT(output_b);
    CHECK_INPUT(norm_w);
    CHECK_INPUT(norm_b);

    unsigned bsz = g_output.size(0);

    std::shared_ptr<BertTransformerLayer<T>> layer =
        std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

    unsigned seq_len = layer->GetSeqLength();
    if (g_output.size(1) != seq_len) {
        seq_len = g_output.size(1);
        layer->SetSeqLength(seq_len);
    }

    auto options = torch::TensorOptions()
                       .dtype(g_output.options().dtype())
                       .layout(torch::kStrided)
                       .device(torch::kCUDA)
                       .requires_grad(true);

    auto workspace = torch::empty({get_workspace_size<T>(bsz,
                                                         seq_len,
                                                         layer->GetHiddenSize(),
                                                         layer->GetIntermediateSize(),
                                                         layer->GetNumHeads(),
                                                         layer->IsTrainingMode(),
                                                         layer->GeluCheckpoint())},
                                  options);
    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());

    auto grad_input = torch::empty_like(input);
    auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
    auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
    auto grad_attn_ow = torch::empty_like(attn_ow);
    auto grad_attn_ob = torch::empty_like(attn_ob);
    auto grad_attn_nw = torch::empty_like(attn_nw);
    auto grad_attn_nb = torch::empty_like(attn_nb);
    auto grad_inter_w = torch::empty_like(inter_w);
    auto grad_inter_b = torch::empty_like(inter_b);
    auto grad_output_w = torch::empty_like(output_w);
    auto grad_output_b = torch::empty_like(output_b);
    auto grad_norm_w = torch::empty_like(norm_w);
    auto grad_norm_b = torch::empty_like(norm_b);

    // inputs.
    const T* grad_output_ptr = (const T*)g_output.data_ptr();
    const T* input_ptr = (const T*)input.data_ptr();
    const T* output_ptr = (const T*)output.data_ptr();
    const T* inp_norm_ptr = (const T*)inp_norm.data_ptr();
    const T* q_tf_ptr = (const T*)qkv_tf.data_ptr();
    const T* add_res_ptr = (const T*)add_res.data_ptr();
    const T* k_tf_ptr =
        q_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));  //(const T*)k_tf.data_ptr();
    const T* v_tf_ptr =
        k_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));  //(const T*)v_tf.data_ptr();
    const T* ff1_inp_ptr = (const T*)ff1_inp.data_ptr();
    const T* gelu_inp_ptr = (const T*)gelu_inp.data_ptr();
    const T* ff2_inp_ptr = (const T*)ff2_inp.data_ptr();
    const T* ctx_bufB_ptr = (const T*)ctx_bufB.data_ptr();
    const T* soft_out_ptr = (const T*)soft_out.data_ptr();
    const T* attn_o_inp_ptr = (const T*)attn_o_inp.data_ptr();
    const T* input_mask_ptr = (const T*)input_mask.data_ptr();
    const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
    const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
    const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
    const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
    const T* inter_w_ptr = (const T*)inter_w.data_ptr();
    const T* inter_b_ptr = (const T*)inter_b.data_ptr();
    const T* output_w_ptr = (const T*)output_w.data_ptr();
    const T* norm_w_ptr = (const T*)norm_w.data_ptr();
    const T* norm_b_ptr = (const T*)norm_b.data_ptr();

    // outputs.
    T* grad_input_ptr = (T*)grad_input.data_ptr();
    T* grad_attn_qkvw_ptr = (T*)grad_attn_qkvw.data_ptr();
    T* grad_attn_qkvb_ptr = (T*)grad_attn_qkvb.data_ptr();
    T* grad_attn_ow_ptr = (T*)grad_attn_ow.data_ptr();
    T* grad_attn_ob_ptr = (T*)grad_attn_ob.data_ptr();
    T* grad_attn_nw_ptr = (T*)grad_attn_nw.data_ptr();
    T* grad_attn_nb_ptr = (T*)grad_attn_nb.data_ptr();
    T* grad_inter_w_ptr = (T*)grad_inter_w.data_ptr();
    T* grad_inter_b_ptr = (T*)grad_inter_b.data_ptr();
    T* grad_output_w_ptr = (T*)grad_output_w.data_ptr();
    T* grad_output_b_ptr = (T*)grad_output_b.data_ptr();
    T* grad_norm_w_ptr = (T*)grad_norm_w.data_ptr();
    T* grad_norm_b_ptr = (T*)grad_norm_b.data_ptr();

    layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
                                  (uint8_t*)attn_output_dropout_mask.data_ptr(),
                                  (uint8_t*)layer_output_dropout_mask.data_ptr(),
                                  (T*)attn_layer_norm_var.data_ptr(),
                                  (T*)attn_layer_norm_mean.data_ptr(),
                                  (T*)layer_norm_var.data_ptr(),
                                  (T*)layer_norm_mean.data_ptr());

    layer->Backward(bsz,
                    grad_output_ptr,
                    input_ptr,
                    output_ptr,
                    inp_norm_ptr,
                    q_tf_ptr,
                    k_tf_ptr,
                    v_tf_ptr,
                    soft_out_ptr,
                    ctx_bufB_ptr,
                    attn_o_inp_ptr,
                    add_res_ptr,
                    ff1_inp_ptr,
                    gelu_inp_ptr,
                    ff2_inp_ptr,
                    input_mask_ptr,
                    attn_qkvw_ptr,
                    attn_ow_ptr,
                    attn_nw_ptr,
                    attn_nb_ptr,
                    inter_w_ptr,
                    inter_b_ptr,
                    output_w_ptr,
                    norm_w_ptr,
                    norm_b_ptr,
                    grad_input_ptr,
                    grad_attn_qkvw_ptr,
                    grad_attn_qkvb_ptr,
                    grad_attn_ow_ptr,
                    grad_attn_ob_ptr,
                    grad_attn_nw_ptr,
                    grad_attn_nb_ptr,
                    grad_inter_w_ptr,
                    grad_inter_b_ptr,
                    grad_output_w_ptr,
                    grad_output_b_ptr,
                    grad_norm_w_ptr,
                    grad_norm_b_ptr);

    return {grad_input,
            grad_attn_qkvw,
            grad_attn_qkvb,
            grad_attn_ow,
            grad_attn_ob,
            grad_attn_nw,
            grad_attn_nb,
            grad_inter_w,
            grad_inter_b,
            grad_output_w,
            grad_output_b,
            grad_norm_w,
            grad_norm_b};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward_fp32",
          &ds_transformer_forward<float>,
          "DeepSpeed Transformer forward with fp32 (CUDA)");
    m.def("forward_fp16",
          &ds_transformer_forward<__half>,
          "DeepSpeed Transformer forward with fp16 (CUDA)");
    m.def("backward_fp32",
          &ds_transformer_backward<float>,
          "DeepSpeed Transformer backward with fp32 (CUDA)");
    m.def("backward_fp16",
          &ds_transformer_backward<__half>,
          "DeepSpeed Transformer backward with fp16 (CUDA)");
    m.def("create_transformer_layer_fp32",
          &create_transformer_layer<float>,
          "Create DeepSpeed Transformer Transformer Layer with fp32 (CUDA)");
    m.def("create_transformer_layer_fp16",
          &create_transformer_layer<__half>,
          "Create DeepSpeed Transformer Transformer Layer with fp16 (CUDA)");
}
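Note that get_workspace_size above returns an element count rather than a byte count (the * sizeof(T) is commented out); the pybind entry points allocate it as a tensor of the layer's dtype. The snippet below is not part of the deleted file; it simply re-evaluates the same formula for one assumed configuration to show the order of magnitude of the training workspace.

// Re-evaluation of the workspace formula from get_workspace_size for an assumed
// BERT-large-like configuration (batch 8, seq 128, hidden 1024, 16 heads, 4x FFN).
// The numbers are hypothetical inputs chosen for illustration, not values from the commit.
#include <algorithm>
#include <cstdio>

int main()
{
    const size_t batch = 8, seq = 128, hidden = 1024, inter = 4 * hidden, heads = 16;
    const bool training = true, gelu_checkpoint = false;

    size_t elems = 4 * (batch * seq * hidden);               // activation buffers
    if (training) {
        elems += 2 * (batch * seq * hidden);                 // extra training buffers
        elems += std::max(batch * seq * inter,               // FFN intermediate ...
                          2 * (batch * heads * seq * seq));  // ... or attention scores
        if (gelu_checkpoint) elems += 2 * (batch * seq * inter);
    }
    std::printf("workspace elements: %zu (%.1f MiB at fp16)\n",
                elems, elems * 2.0 / (1024.0 * 1024.0));
    return 0;
}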
csrc/transformer/gelu_kernels.hip
deleted
100644 → 0
View file @
7dd68788
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
inline __device__ float gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}
inline __device__ float d_gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
float x2mul = x * x * mul_param;
float tan_h = tanhf(sqrt_param * (x + x * x2mul));
float dg1 = 0.5f * (1.0f + tan_h);
float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
float dg3 = dg2 * 3 * x2mul;
return (dg1 + dg2 + dg3);
}
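/* Reference annotation (added for clarity, not in the original file): with
   u = sqrt(2/pi) * (x + 0.044715 * x^3), the forward pass above computes
   gelu(x) ~= 0.5 * x * (1 + tanh(u)), and by the product rule its derivative is
   d/dx gelu(x) ~= 0.5 * (1 + tanh(u))
                 + 0.5 * x * (1 - tanh(u)^2) * sqrt(2/pi) * (1 + 3 * 0.044715 * x^2),
   which d_gelu accumulates term by term as dg1 + dg2 + dg3. */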
/*
Fused bias add with GELU

Loads a vector of 4 elements per iteration, for `iterations` passes over the
row. It was written with the intention of launching 256-thread blocks, so for
BERT-Large we would set ITERATIONS to 4. This is currently derived
automatically as a heuristic, treating each row as blocks of 1024 elements.

For FP16, values are loaded from memory as __half but converted to FP32 for
the arithmetic itself, to prevent overflow in the intermediate hyperbolic
tangent, since there is no intrinsic that computes it directly.
*/
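/* Illustrative sketch (added; not part of the original file). This works the
   heuristic described above for an assumed BERT-Large intermediate size of
   4096 elements per row: blocks of 1024 give iterations = 4, and since each
   thread loads one float4 (4 elements) per iteration, threads = 256 per block. */
static inline void example_gelu_launch_dims(int intermediate_size, // e.g. 4096 (assumed)
                                            int& iterations,
                                            int& threads)
{
    iterations = (intermediate_size + 1023) / 1024;           // 4 for 4096
    threads = (intermediate_size - 1) / (iterations * 4) + 1; // 256 for 4096
}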
__global__ void gelu_kernel(const float* input, float* vals, int row_stride, int iterations)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
const float4* input_cast = reinterpret_cast<const float4*>(input);
float4* vals_cast = reinterpret_cast<float4*>(vals);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 data = input_cast[row * row_stride + i * loop_stride + id];
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
vals_cast[row * row_stride + i * loop_stride + id] = data;
}
}
}
__global__ void gelu_kernel(const __half* input, __half* vals, int row_stride, int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
const float2* input_cast = reinterpret_cast<const float2*>(input);
float2* vals_cast = reinterpret_cast<float2*>(vals);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
}
}
#endif
}
__global__ void fused_bias_gelu(const float* input,
const float* bias,
float* vals,
int row_stride,
int iterations)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
const float4* input_cast = reinterpret_cast<const float4*>(input);
float4* vals_cast = reinterpret_cast<float4*>(vals);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 data = input_cast[row * row_stride + i * loop_stride + id];
float4 bias_data = bias_cast[i * loop_stride + id];
data.x += bias_data.x;
data.y += bias_data.y;
data.z += bias_data.z;
data.w += bias_data.w;
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
vals_cast[row * row_stride + i * loop_stride + id] = data;
}
}
}
__global__ void fused_bias_gelu(const __half* input,
const __half* bias,
__half* vals,
int row_stride,
int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
const float2* input_cast = reinterpret_cast<const float2*>(input);
float2* vals_cast = reinterpret_cast<float2*>(vals);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];
float2 bias_vec = bias_cast[i * loop_stride + id];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
low_data.x += low_bias.x;
low_data.y += low_bias.y;
high_data.x += high_bias.x;
high_data.y += high_bias.y;
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
}
}
#endif
}
__global__ void d_gelu_func(float* d_output,
const float* gelu_input,
const float* bias,
int row_stride,
int iterations)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
float4* d_output_cast = reinterpret_cast<float4*>(d_output);
const float4* gelu_input_cast = reinterpret_cast<const float4*>(gelu_input);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
float4 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
float4 bias_data = bias_cast[i * loop_stride + id];
gelu_input_data.x += bias_data.x;
gelu_input_data.y += bias_data.y;
gelu_input_data.z += bias_data.z;
gelu_input_data.w += bias_data.w;
output_data.x *= d_gelu(gelu_input_data.x);
output_data.y *= d_gelu(gelu_input_data.y);
output_data.z *= d_gelu(gelu_input_data.z);
output_data.w *= d_gelu(gelu_input_data.w);
d_output_cast[row * row_stride + i * loop_stride + id] = output_data;
}
}
}
__global__ void d_gelu_func(__half* d_output,
const __half* gelu_input,
const __half* bias,
int row_stride,
int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
float2* d_output_cast = reinterpret_cast<float2*>(d_output);
const float2* gelu_input_cast = reinterpret_cast<const float2*>(gelu_input);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
#pragma unroll
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
float2 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
float2 bias_vec = bias_cast[i * loop_stride + id];
__half2* output_data_half = reinterpret_cast<__half2*>(&output_data);
__half2* gelu_input_data_half = reinterpret_cast<__half2*>(&gelu_input_data);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 output_half_0 = __half22float2(output_data_half[0]);
float2 output_half_1 = __half22float2(output_data_half[1]);
float2 gelu_input_half_0 = __half22float2(gelu_input_data_half[0]);
float2 gelu_input_half_1 = __half22float2(gelu_input_data_half[1]);
float2 bias_half_0 = __half22float2(bias_half[0]);
float2 bias_half_1 = __half22float2(bias_half[1]);
gelu_input_half_0.x += bias_half_0.x;
gelu_input_half_0.y += bias_half_0.y;
gelu_input_half_1.x += bias_half_1.x;
gelu_input_half_1.y += bias_half_1.y;
output_half_0.x *= d_gelu(gelu_input_half_0.x);
output_half_0.y *= d_gelu(gelu_input_half_0.y);
output_half_1.x *= d_gelu(gelu_input_half_1.x);
output_half_1.y *= d_gelu(gelu_input_half_1.y);
float2 result;
__half2* result_half2 = reinterpret_cast<__half2*>(&result);
result_half2[0] = __float22half2_rn(output_half_0);
result_half2[1] = __float22half2_rn(output_half_1);
d_output_cast[row * row_stride + i * loop_stride + id] = result;
}
}
#endif
}
template <typename T>
void launch_bias_gelu(const T* input,
const T* bias,
T* output,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = (intermediate_size - 1) / (iterations * 4) + 1;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
hipLaunchKernelGGL(( fused_bias_gelu), dim3(grid_dims), dim3(block_dims), 0, stream,
input, bias, output, intermediate_size / 4, iterations);
}
template <typename T>
void launch_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = (intermediate_size - 1) / (iterations * 4) + 1;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
hipLaunchKernelGGL(( gelu_kernel), dim3(grid_dims), dim3(block_dims), 0, stream,
input, output, intermediate_size / 4, iterations);
}
template void launch_bias_gelu<float>(const float*, const float*, float*, int, int, hipStream_t);
template void launch_bias_gelu<__half>(const __half*,
const __half*,
__half*,
int,
int,
hipStream_t);
template void launch_gelu<float>(const float*, float*, int, int, hipStream_t);
template void launch_gelu<__half>(const __half*, __half*, int, int, hipStream_t);
template <typename T>
void launch_d_gelu(T* d_output,
const T* input,
const T* bias,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = (intermediate_size - 1) / (iterations * 4) + 1;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
hipLaunchKernelGGL(( d_gelu_func), dim3(grid_dims), dim3(block_dims), 0, stream,
d_output, input, bias, intermediate_size / 4, iterations);
}
template void launch_d_gelu<float>(float*, const float*, const float*, int, int, hipStream_t);
template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, hipStream_t);
csrc/transformer/general_kernels.hip
deleted
100644 → 0
View file @
7dd68788
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "general_kernels_hip.h"
namespace cg = cooperative_groups;
template <typename T>
__global__ void column_sum_reduce(const T* __restrict__ inp,
T* __restrict__ out,
int rows,
int width)
{
__shared__ float tile[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int y_stride = width * TILE_DIM;
float localSum = 0;
// Loop across matrix height
if (idx < width) {
int offset = threadIdx.y * width + idx;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
localSum += (float)inp[offset];
offset += y_stride;
}
}
tile[threadIdx.x][threadIdx.y] = localSum;
__syncthreads();
// Sum the shared buffer.
float sum = tile[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) sum += g.shfl_down(sum, i);
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
if (pos < width) out[pos] = sum;
}
}
template <typename T>
void launch_fuse_transpose_bias_kernel(const T* inp,
T* out,
int rows,
int cols,
hipStream_t stream);
template <>
void launch_fuse_transpose_bias_kernel<float>(const float* inp,
float* out,
int rows,
int cols,
hipStream_t stream)
{
// assert(rows % TILE_DIM == 0);
// assert(cols % TILE_DIM == 0);
dim3 grid_dim((cols - 1) / TILE_DIM + 1);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( column_sum_reduce<float>), dim3(grid_dim), dim3(block_dim), 0, stream, inp, out, rows, cols);
}
template <>
void launch_fuse_transpose_bias_kernel<__half>(const __half* inp,
__half* out,
int rows,
int cols,
hipStream_t stream)
{
// assert(rows % TILE_DIM == 0);
// assert(cols % TILE_DIM == 0);
dim3 grid_dim((cols - 1) / TILE_DIM + 1);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( column_sum_reduce<__half>), dim3(grid_dim), dim3(block_dim), 0, stream, inp, out, rows, cols);
}
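/* Minimal usage sketch (added; the wrapper name and device pointers are assumed,
   not part of the original file): collapsing a [rows x cols] gradient matrix into
   a per-column bias gradient with the kernel above. */
static void example_bias_grad_reduce(const float* d_grad, // rows * cols, device memory
                                     float* d_bias_grad,  // cols, device memory
                                     int rows,
                                     int cols,
                                     hipStream_t stream)
{
    // Sums over the row dimension, producing one value per column.
    launch_fuse_transpose_bias_kernel<float>(d_grad, d_bias_grad, rows, cols, stream);
}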
__global__ void fused_add2_kernel(const int N, float* out, const float* inp1, const float* inp2)
{
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
float4* out_4 = reinterpret_cast<float4*>(out);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 val;
float4 inp1_reg = inp1_4[j];
float4 inp2_reg = inp2_4[j];
val.x = inp1_reg.x + inp2_reg.x;
val.y = inp1_reg.y + inp2_reg.y;
val.z = inp1_reg.z + inp2_reg.z;
val.w = inp1_reg.w + inp2_reg.w;
out_4[j] = val;
}
}
__global__ void fused_add2_kernel(const int N, __half* out, const __half* inp1, const __half* inp2)
{
float2 inp1_4;
float2 inp2_4;
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
CUDA_1D_KERNEL_LOOP(j, N)
{
inp1_4 = inp1_arr[j];
inp2_4 = inp2_arr[j];
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
inp1_h_f_0.x += inp2_h_f_0.x;
inp1_h_f_0.y += inp2_h_f_0.y;
inp1_h_f_1.x += inp2_h_f_1.x;
inp1_h_f_1.y += inp2_h_f_1.y;
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[j] = val_f;
}
}
template <>
void launch_fused_add2<float>(float* out,
const float* inp1,
const float* inp2,
int batch_size,
int seq_length,
int hidden_dim,
hipStream_t& stream)
{
int total_count = batch_size * seq_length * hidden_dim / 4;
dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length);
dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
hipLaunchKernelGGL(( fused_add2_kernel), dim3(grid_dim), dim3(block_dim), 0, stream, total_count, out, inp1, inp2);
}
template <>
void launch_fused_add2<__half>(__half* out,
const __half* inp1,
const __half* inp2,
int batch_size,
int seq_length,
int hidden_dim,
hipStream_t& stream)
{
int total_count = batch_size * seq_length * hidden_dim / 4;
dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length);
dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
hipLaunchKernelGGL(( fused_add2_kernel), dim3(grid_dim), dim3(block_dim), 0, stream, total_count, out, inp1, inp2);
}
__global__ void fused_add3_kernel(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
float4* out_4 = reinterpret_cast<float4*>(out);
float4 val;
float4 inp1_reg = inp1_4[row * row_stride + id];
float4 inp2_reg = inp2_4[row * row_stride + id];
float4 inp3_reg = inp3_4[row * row_stride + id];
val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x;
val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y;
val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z;
val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w;
out_4[row * row_stride + id] = val;
}
__global__ void fused_add3_kernel(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
float2 inp1_4 = inp1_arr[row * row_stride + id];
float2 inp2_4 = inp2_arr[row * row_stride + id];
float2 inp3_4 = inp3_arr[row * row_stride + id];
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
__half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
float2 inp3_h_f_1 = __half22float2(inp3_h[1]);
inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x);
inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y);
inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x);
inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y);
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[row * row_stride + id] = val_f;
}
template <>
void launch_fused_add3<float>(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add3_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
template <>
void launch_fused_add3<__half>(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add3_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
__global__ void fused_add4_kernel(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
const float* inp4,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
const float4* inp4_4 = reinterpret_cast<const float4*>(inp4);
float4* out_4 = reinterpret_cast<float4*>(out);
float4 val;
float4 inp1_reg = inp1_4[row * row_stride + id];
float4 inp2_reg = inp2_4[row * row_stride + id];
float4 inp3_reg = inp3_4[row * row_stride + id];
float4 inp4_reg = inp4_4[row * row_stride + id];
val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x + inp4_reg.x;
val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y + inp4_reg.y;
val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z + inp4_reg.z;
val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w + inp4_reg.w;
out_4[row * row_stride + id] = val;
}
__global__ void fused_add4_kernel(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
const __half* inp4,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
const float2* inp4_arr = reinterpret_cast<const float2*>(inp4);
float2 inp1_4 = inp1_arr[row * row_stride + id];
float2 inp2_4 = inp2_arr[row * row_stride + id];
float2 inp3_4 = inp3_arr[row * row_stride + id];
float2 inp4_4 = inp4_arr[row * row_stride + id];
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
__half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
__half2* inp4_h = reinterpret_cast<__half2*>(&inp4_4);
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
float2 inp3_h_f_1 = __half22float2(inp3_h[1]);
float2 inp4_h_f_0 = __half22float2(inp4_h[0]);
float2 inp4_h_f_1 = __half22float2(inp4_h[1]);
inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x + inp4_h_f_0.x);
inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y + inp4_h_f_0.y);
inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x + inp4_h_f_1.x);
inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y + inp4_h_f_1.y);
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[row * row_stride + id] = val_f;
}
template <>
void launch_fused_add4<float>(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
const float* inp4,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add4_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
template <>
void launch_fused_add4<__half>(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
const __half* inp4,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add4_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
csrc/transformer/inference/csrc/apply_rotary_pos_emb.hip
deleted
100644 → 0
View file @
7dd68788
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include "inference_cuda_layers.h"
#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif
namespace cg = cooperative_groups;
__global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
unsigned seq_index = head_id % seq_len;
unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[k_offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[k_offset + lane] = k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
unsigned seq_index = head_id % seq_len;
unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[k_offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[k_offset + lane] = (__half)k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb1(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
unsigned seq_index = head_id % seq_len;
unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[k_offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[k_offset + lane] = k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb1(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned seq_index = head_id % seq_len;
unsigned offset = head_id * head_size;
unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size;
constexpr unsigned mask[32] = {
0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000, 0x10 | 0x10000,
0x20 | 0x20000, 0x40 | 0x40000, 0x80 | 0x80000, 0x100 | 0x100000, 0x200 | 0x200000,
0x400 | 0x400000, 0x800 | 0x800000, 0x1000 | 0x1, 0x2000 | 0x2, 0x4000 | 0x4,
0x8000 | 0x8, 0x10000 | 0x10, 0x20000 | 0x20, 0x40000 | 0x40, 0x80000 | 0x80,
0x100000 | 0x100, 0x200000 | 0x200, 0x400000 | 0x400, 0x800000 | 0x800, 0x1000000,
0x2000000, 0x4000000, 0x8000000, 0x10000000, 0x20000000,
0x40000000, 0x80000000};
unsigned seq_id = (head_id % seq_len) + seq_offset;
unsigned half_dim = rotary_dim >> 1;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane % half_dim) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[k_offset + lane];
float rotary_sign = (lane > (half_dim - 1) ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
auto q_rot_tmp = lane < half_dim ? __shfl_sync(mask[lane], q_rot, lane + half_dim)
: __shfl_sync(mask[lane], q_rot, lane - half_dim);
auto k_rot_tmp = lane < half_dim ? __shfl_sync(mask[lane], k_rot, lane + half_dim)
: __shfl_sync(mask[lane], k_rot, lane - half_dim);
q = q * cosf(inv_freq) + q_rot_tmp * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[k_offset + lane] = (__half)k;
lane += WARP_SIZE;
}
}
}
template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
T* key_layer,
unsigned head_size,
unsigned seq_len,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
unsigned batch,
bool rotate_half,
bool rotate_every_two,
hipStream_t stream,
int max_out_tokens)
{
int total_count = batch * num_heads * seq_len;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size);
if (rotate_every_two)
hipLaunchKernelGGL(( apply_rotary_pos_emb), dim3(grid_dims), dim3(block_dims), 0, stream, mixed_query,
key_layer,
rotary_dim,
seq_len,
offset,
num_heads,
head_size,
total_count,
max_out_tokens);
else if (rotate_half)
hipLaunchKernelGGL(( apply_rotary_pos_emb1), dim3(grid_dims), dim3(block_dims), 0, stream, mixed_query,
key_layer,
rotary_dim,
seq_len,
offset,
num_heads,
head_size,
total_count,
max_out_tokens);
}
template void launch_apply_rotary_pos_emb<float>(float*,
float*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
bool,
bool,
hipStream_t,
int);
template void launch_apply_rotary_pos_emb<__half>(__half*,
__half*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
bool,
bool,
hipStream_t,
int);
/*
__global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[offset + lane] = k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
#if __CUDA_ARCH__ >= 700
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
constexpr unsigned mask[32] = {0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000,
0x10 | 0x10000, 0x20 | 0x20000, 0x40 | 0x40000, 0x80 | 0x80000,
0x100 | 0x100000, 0x200 | 0x200000, 0x400 | 0x400000, 0x800 | 0x800000,
0x1000 | 0x1, 0x2000 | 0x2, 0x4000 | 0x4, 0x8000 | 0x8,
0x10000 | 0x10, 0x20000 | 0x20, 0x40000 | 0x40, 0x80000 | 0x80,
0x100000 | 0x100, 0x200000 | 0x200, 0x400000 | 0x400, 0x800000 | 0x800,
0x1000000, 0x2000000, 0x4000000, 0x8000000,
0x10000000, 0x20000000, 0x40000000, 0x80000000};
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
//float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
float inv_freq = (float)((lane % (rotary_dim >> 1)) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[offset + lane];
float rotary_sign = (lane > 11 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
auto q_rot_tmp = lane < 12 ? __shfl_sync(mask[lane], q_rot, lane + 12) : __shfl_sync(mask[lane],
q_rot, lane - 12);//g.shfl_xor(q_rot, 12); auto k_rot_tmp = lane < 12 ? __shfl_sync(mask[lane],
k_rot, lane + 12) : __shfl_sync(mask[lane], k_rot, lane - 12);//g.shfl_xor(k_rot, 12); q = q *
cosf(inv_freq) + q_rot_tmp * sinf(inv_freq); k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[offset + lane] = (__half)k;
lane += WARP_SIZE;
}
}
#endif
}
template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
T* key_layer,
unsigned head_size,
unsigned seq_len,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
unsigned batch,
hipStream_t stream)
{
int total_count = batch * num_heads * seq_len;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size);
hipLaunchKernelGGL((
apply_rotary_pos_emb), dim3(grid_dims), dim3(block_dims), 0, stream,
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
}
template void launch_apply_rotary_pos_emb<float>(float*,
float*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
hipStream_t);
template void launch_apply_rotary_pos_emb<__half>(__half*,
__half*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
hipStream_t);
*/
csrc/transformer/inference/csrc/dequantize.hip
deleted
100644 → 0
View file @
7dd68788
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include "inference_cuda_layers.h"
#define MAX_QUANTIZE_GROUPING 1024
#define loop_unroll 1
#define loop_unroll_bits 1
__global__ void dequantize_kernel(float* output,
const int8_t* input,
const float* qscale,
int output_size,
int hidden_dim,
int groups,
int merge_count)
{
unsigned merge_hidden = hidden_dim >> merge_count;
unsigned quantization_stride = (merge_hidden * output_size) / groups;
unsigned bid = blockIdx.x;
unsigned tid = threadIdx.x;
while (tid < output_size) {
unsigned w_index = bid / merge_hidden;
unsigned q_index = tid + bid * output_size;
auto q = input[q_index];
unsigned merge_hidden_total = w_index * merge_hidden;
unsigned scale_index =
((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride)
<< merge_count) +
w_index;
float scale_data = qscale[scale_index];
output[q_index] = (scale_data * (float)q);
tid += blockDim.x;
}
}
__global__ void dequantize_kernel(__half* output,
const int8_t* input,
const float* qscale,
unsigned output_size,
unsigned hidden_dim,
unsigned groups,
unsigned merge_count)
{
unsigned merge_hidden = hidden_dim >> merge_count;
unsigned quantization_stride = (merge_hidden * output_size) / groups;
unsigned bid = blockIdx.x;
unsigned tid = threadIdx.x;
while (tid < output_size) {
unsigned w_index = bid / merge_hidden;
unsigned q_index = tid + bid * output_size;
auto q = input[q_index];
unsigned merge_hidden_total = w_index * merge_hidden;
unsigned scale_index =
((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride)
<< merge_count) +
w_index;
float scale_data = qscale[scale_index];
output[q_index] = __float2half(scale_data * (float)q);
tid += blockDim.x;
}
}
template <typename T>
void launch_dequantize(T* output,
const int8_t* input,
const float* qscale,
unsigned output_size,
unsigned hidden_dim,
unsigned groups,
unsigned merge_count,
hipStream_t stream)
{
unsigned threads = 1024;
dim3 block_dims(threads);
dim3 grid_dims(hidden_dim);
hipLaunchKernelGGL(( dequantize_kernel), dim3(grid_dims), dim3(block_dims), 0, stream,
output, input, qscale, output_size, hidden_dim, groups, merge_count);
}
template void launch_dequantize<float>(float*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
unsigned,
hipStream_t);
template void launch_dequantize<__half>(__half*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
unsigned,
hipStream_t);
__global__ void dequantize_kernel(float* output,
const int8_t* input,
const float* qscale,
int hidden_dim,
unsigned merge_hidden,
int cnt)
{
}
__global__ void dequantize_kernel(__half* output,
const int8_t* input,
const float* qscale,
unsigned hidden_dim,
unsigned merge_hidden,
int cnt)
{
unsigned bid = blockIdx.x * gridDim.y + blockIdx.y;
unsigned tid = threadIdx.x;
float local_scale = qscale[blockIdx.x];
const float* input_cast = reinterpret_cast<const float*>(input);
float2* output_cast = reinterpret_cast<float2*>(output);
input_cast += bid * merge_hidden;
output_cast += bid * merge_hidden;
for (int c = 0; c < cnt; c++) {
if (tid < merge_hidden) {
float q = input_cast[tid];
int8_t* q_int8 = (int8_t*)&q;
float2 q_f;
__half* q_h = (__half*)&q_f;
q_h[0] = __float2half(local_scale * (float)q_int8[0]);
q_h[1] = __float2half(local_scale * (float)q_int8[1]);
q_h[2] = __float2half(local_scale * (float)q_int8[2]);
q_h[3] = __float2half(local_scale * (float)q_int8[3]);
output_cast[tid] = q_f;
tid += blockDim.x;
}
}
}
template <typename T>
void launch_dequantize(T* output,
const int8_t* input,
const float* qscale,
unsigned output_size,
unsigned hidden_dim,
unsigned groups,
hipStream_t stream)
{
unsigned threads = 1024;
hidden_dim /= 4;
unsigned hid_cnt = threads / hidden_dim;
unsigned thd_cnt = (hidden_dim - 1) / threads + 1;
hid_cnt = hid_cnt > 0 ? hid_cnt : 1;
unsigned blocks = (output_size + hid_cnt * groups - 1) / (hid_cnt * groups);
dim3 block_dims(threads);
dim3 grid_dims(groups, blocks);
hipLaunchKernelGGL(( dequantize_kernel), dim3(grid_dims), dim3(block_dims), 0, stream,
output, input, qscale, hidden_dim, hid_cnt * hidden_dim, thd_cnt);
}
template void launch_dequantize<float>(float*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
hipStream_t);
template void launch_dequantize<__half>(__half*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
hipStream_t);
csrc/transformer/inference/csrc/gelu.hip
deleted
100644 → 0
View file @
7dd68788
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include "conversion_utils.h"
#include "inference_cuda_layers.h"
#include "memory_access_utils.h"
namespace cg = cooperative_groups;
#define MAX_CAP 4
#define MAX_SEQ 2048
inline __device__ float gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}
/*
In-place gelu(biasAdd(x)) for channels last
*/
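/* Added annotation (not in the original file): "in-place, channels last" here
   means the activation is laid out as [batch * seq, intermediate_size] and is
   overwritten with gelu(x + bias), the bias being broadcast along the last dimension:
       input[row * intermediate_size + c] =
           gelu(input[row * intermediate_size + c] + bias[c]); */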
template <typename T>
__global__ void fused_bias_gelu(T* input, const T* bias, int total_count, int intermediate_size)
{
// Input restriction: intermediate_size % vals_per_access == 0
constexpr int granularity = 16;
constexpr int values_per_access = granularity / sizeof(T);
const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * values_per_access;
if (offset < total_count) {
T data[values_per_access];
T data_bias[values_per_access];
mem_access::load_global<granularity>(data, input + offset);
mem_access::load_global<granularity>(data_bias, bias + (offset % intermediate_size));
#pragma unroll
for (int i = 0; i < values_per_access; i++) {
float data_f = conversion::to<float>(data[i]);
float bias_f = conversion::to<float>(data_bias[i]);
data[i] = conversion::to<T>(gelu(data_f + bias_f));
}
mem_access::store_global<granularity>(input + offset, data);
}
}
template <typename T>
void launch_bias_gelu(T* input,
const T* bias,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
constexpr int threads = 1024;
constexpr int granularity = 16;
const int total_count = batch_size * intermediate_size;
const int elems_per_block = threads * (granularity / sizeof(T));
dim3 block_dims(threads);
dim3 grid_dims((total_count + elems_per_block - 1) / elems_per_block);
hipLaunchKernelGGL(( fused_bias_gelu), dim3(grid_dims), dim3(block_dims), 0, stream,
input, bias, total_count, intermediate_size);
}
template void launch_bias_gelu<float>(float*, const float*, int, int, hipStream_t);
template void launch_bias_gelu<__half>(__half*, const __half*, int, int, hipStream_t);
/*
In-place channels-last bias add
*/
template <typename T>
__global__ void fused_bias_add(T* input, const T* bias, int total_count, int intermediate_size)
{
// Input restriction: intermediate_size % vals_per_access == 0
constexpr int granularity = 16;
constexpr int values_per_access = granularity / sizeof(T);
const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * values_per_access;
if (offset < total_count) {
T data[values_per_access];
T data_bias[values_per_access];
mem_access::load_global<granularity>(data, input + offset);
mem_access::load_global<granularity>(data_bias, bias + (offset % intermediate_size));
#pragma unroll
for (int i = 0; i < values_per_access; i++) {
float data_f = conversion::to<float>(data[i]);
float bias_f = conversion::to<float>(data_bias[i]);
data[i] = conversion::to<T>(data_f + bias_f);
}
mem_access::store_global<granularity>(input + offset, data);
}
}
template <typename T>
void launch_bias_add(T* input,
const T* bias,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
constexpr int threads = 1024;
constexpr int granularity = 16;
const int total_count = batch_size * intermediate_size;
const int elems_per_block = threads * (granularity / sizeof(T));
dim3 block_dims(threads);
dim3 grid_dims((total_count + elems_per_block - 1) / elems_per_block);
hipLaunchKernelGGL(( fused_bias_add), dim3(grid_dims), dim3(block_dims), 0, stream,
input, bias, total_count, intermediate_size);
}
template void launch_bias_add<float>(float*, const float*, int, int, hipStream_t);
template void launch_bias_add<__half>(__half*, const __half*, int, int, hipStream_t);
__global__ void fused_bias_residual(float* residual,
const float* hidden_state,
const float* attn,
const float* bias,
const float* attn_bias,
const int total_count,
const int intermediate_size,
const float mp_scale,
const bool preln)
{
float4* res_fl4_ptr = reinterpret_cast<float4*>(residual);
const float4* hs_fl4_ptr = reinterpret_cast<const float4*>(hidden_state);
const float4* attn_fl4_ptr = reinterpret_cast<const float4*>(attn);
const float4* bias_fl4_ptr = reinterpret_cast<const float4*>(bias);
const float4* attn_bias_fl4_ptr = reinterpret_cast<const float4*>(attn_bias);
const int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 res_fl4 = res_fl4_ptr[offset];
const float4 hs_fl4 = hs_fl4_ptr[offset];
const float4 attn_fl4 = attn_fl4_ptr[offset];
const float4 bias_fl4 = bias_fl4_ptr[offset % intermediate_size];
const float4 attn_bias_fl4 = attn_bias_fl4_ptr[offset % intermediate_size];
if (preln) {
// residual = (residual + attention + bias + attention_bias) *
// mp_scale + hidden_state
res_fl4.x =
(res_fl4.x + attn_fl4.x + bias_fl4.x + attn_bias_fl4.x) * mp_scale + (hs_fl4.x);
res_fl4.y =
(res_fl4.y + attn_fl4.y + bias_fl4.y + attn_bias_fl4.y) * mp_scale + (hs_fl4.y);
res_fl4.z =
(res_fl4.z + attn_fl4.z + bias_fl4.z + attn_bias_fl4.z) * mp_scale + (hs_fl4.z);
res_fl4.w =
(res_fl4.w + attn_fl4.w + bias_fl4.w + attn_bias_fl4.w) * mp_scale + (hs_fl4.w);
} else {
// residual += hidden_state + bias
res_fl4.x = res_fl4.x + hs_fl4.x + bias_fl4.x;
res_fl4.y = res_fl4.y + hs_fl4.y + bias_fl4.y;
res_fl4.z = res_fl4.z + hs_fl4.z + bias_fl4.z;
res_fl4.w = res_fl4.w + hs_fl4.w + bias_fl4.w;
}
res_fl4_ptr[offset] = res_fl4;
}
}
__global__ void fused_bias_residual(__half* residual,
const __half* hidden_state,
const __half* attn,
const __half* bias,
const __half* attn_bias,
const int total_count,
const int intermediate_size,
const float mp_scale,
const bool preln)
{
float2* res_fl2_ptr = reinterpret_cast<float2*>(residual);
const float2* hs_fl2_ptr = reinterpret_cast<const float2*>(hidden_state);
const float2* attn_fl2_ptr = reinterpret_cast<const float2*>(attn);
const float2* bias_fl2_ptr = reinterpret_cast<const float2*>(bias);
const float2* attn_bias_fl2_ptr = reinterpret_cast<const float2*>(attn_bias);
const int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 res_fl2 = res_fl2_ptr[offset];
const float2 hs_fl2 = hs_fl2_ptr[offset];
const float2 attn_fl2 = attn_fl2_ptr[offset];
const float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size];
const float2 attn_bias_fl2 = attn_bias_fl2_ptr[offset % intermediate_size];
__half2* res_half2 = reinterpret_cast<__half2*>(&res_fl2);
const __half2* hs_half2 = reinterpret_cast<const __half2*>(&hs_fl2);
const __half2* attn_half2 = reinterpret_cast<const __half2*>(&attn_fl2);
const __half2* bias_half2 = reinterpret_cast<const __half2*>(&bias_fl2);
const __half2* attn_bias_half2 = reinterpret_cast<const __half2*>(&attn_bias_fl2);
float2 res_low = __half22float2(res_half2[0]);
float2 res_high = __half22float2(res_half2[1]);
const float2 hs_low = __half22float2(hs_half2[0]);
const float2 hs_high = __half22float2(hs_half2[1]);
const float2 attn_low = __half22float2(attn_half2[0]);
const float2 attn_high = __half22float2(attn_half2[1]);
const float2 bias_low = __half22float2(bias_half2[0]);
const float2 bias_high = __half22float2(bias_half2[1]);
const float2 attn_bias_low = __half22float2(attn_bias_half2[0]);
const float2 attn_bias_high = __half22float2(attn_bias_half2[1]);
if (preln) {
// residual = (residual + attention + bias + attention_bias) *
// mp_scale + hidden_state
res_low.x =
(res_low.x + attn_low.x + bias_low.x + attn_bias_low.x) * mp_scale + hs_low.x;
res_low.y =
(res_low.y + attn_low.y + bias_low.y + attn_bias_low.y) * mp_scale + hs_low.y;
res_high.x =
(res_high.x + attn_high.x + bias_high.x + attn_bias_high.x) * mp_scale + hs_high.x;
res_high.y =
(res_high.y + attn_high.y + bias_high.y + attn_bias_high.y) * mp_scale + hs_high.y;
} else {
// residual += hidden_state + bias
res_low.x = (res_low.x + hs_low.x + bias_low.x);
res_low.y = (res_low.y + hs_low.y + bias_low.y);
res_high.x = (res_high.x + hs_high.x + bias_high.x);
res_high.y = (res_high.y + hs_high.y + bias_high.y);
}
res_half2[0] = __float22half2_rn(res_low);
res_half2[1] = __float22half2_rn(res_high);
res_fl2_ptr[offset] = res_fl2;
}
}
template <typename T>
void launch_bias_residual(T* residual,
T* hidden_state,
T* attn,
T* bias,
T* attn_bias,
int batch,
int hidden_dim,
int mp_size,
bool preln,
hipStream_t stream)
{
int total_count = batch * hidden_dim / 4;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size);
hipLaunchKernelGGL(( fused_bias_residual), dim3(grid_dims), dim3(block_dims), 0, stream, residual,
hidden_state,
attn,
bias,
attn_bias,
total_count,
hidden_dim / 4,
1.0 / mp_size,
preln);
}
template void launch_bias_residual<
float>(float*, float*, float*, float*, float*, int, int, int, bool, hipStream_t);
template void launch_bias_residual<
__half>(__half*, __half*, __half*, __half*, __half*, int, int, int, bool, hipStream_t);
__global__ void gptj_residual_add(float* residual,
const float* hidden_state,
const float* attn,
const float* bias,
const float* attn_bias,
const int total_count,
const int intermediate_size,
const float mp_scale)
{
float4* res_fl4_ptr = reinterpret_cast<float4*>(residual);
const float4* hs_fl4_ptr = reinterpret_cast<const float4*>(hidden_state);
const float4* attn_fl4_ptr = reinterpret_cast<const float4*>(attn);
const float4* bias_fl4_ptr = reinterpret_cast<const float4*>(bias);
const float4* attn_bias_fl4_ptr = reinterpret_cast<const float4*>(attn_bias);
const int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 res_fl4 = res_fl4_ptr[offset];
const float4 hs_fl4 = hs_fl4_ptr[offset];
const float4 attn_fl4 = attn_fl4_ptr[offset];
const float4 bias_fl4 = bias_fl4_ptr[offset % intermediate_size];
if (attn_bias) {
float4 attn_bias_fl4 = attn_bias_fl4_ptr[offset % intermediate_size];
// residual += attention_bias
res_fl4.x += attn_bias_fl4.x;
res_fl4.y += attn_bias_fl4.y;
res_fl4.z += attn_bias_fl4.z;
res_fl4.w += attn_bias_fl4.w;
}
// residual = hidden_state + attention + (residual + bias) * mp_scale
res_fl4.x = hs_fl4.x + attn_fl4.x + (res_fl4.x + bias_fl4.x) * mp_scale;
res_fl4.y = hs_fl4.y + attn_fl4.y + (res_fl4.y + bias_fl4.y) * mp_scale;
res_fl4.z = hs_fl4.z + attn_fl4.z + (res_fl4.z + bias_fl4.z) * mp_scale;
res_fl4.w = hs_fl4.w + attn_fl4.w + (res_fl4.w + bias_fl4.w) * mp_scale;
res_fl4_ptr[offset] = res_fl4;
}
}
__global__ void gptj_residual_add(__half* residual,
const __half* hidden_state,
const __half* attn,
const __half* bias,
const __half* attn_bias,
const int total_count,
const int intermediate_size,
const float mp_scale)
{
float2* res_fl2_ptr = reinterpret_cast<float2*>(residual);
const float2* hs_fl2_ptr = reinterpret_cast<const float2*>(hidden_state);
const float2* attn_fl2_ptr = reinterpret_cast<const float2*>(attn);
const float2* bias_fl2_ptr = reinterpret_cast<const float2*>(bias);
const float2* attn_bias_fl2_ptr = reinterpret_cast<const float2*>(attn_bias);
const int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 res_fl2 = res_fl2_ptr[offset];
const float2 hs_fl2 = hs_fl2_ptr[offset];
const float2 attn_fl2 = attn_fl2_ptr[offset];
const float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size];
__half2* res_half2 = reinterpret_cast<__half2*>(&res_fl2);
const __half2* hs_half2 = reinterpret_cast<const __half2*>(&hs_fl2);
const __half2* attn_half2 = reinterpret_cast<const __half2*>(&attn_fl2);
const __half2* bias_half2 = reinterpret_cast<const __half2*>(&bias_fl2);
float2 res_low = __half22float2(res_half2[0]);
float2 res_high = __half22float2(res_half2[1]);
const float2 hs_low = __half22float2(hs_half2[0]);
const float2 hs_high = __half22float2(hs_half2[1]);
const float2 attn_low = __half22float2(attn_half2[0]);
const float2 attn_high = __half22float2(attn_half2[1]);
const float2 bias_low = __half22float2(bias_half2[0]);
const float2 bias_high = __half22float2(bias_half2[1]);
if (attn_bias) {
const float2 attn_bias_fl2 = attn_bias_fl2_ptr[offset % intermediate_size];
const __half2* attn_bias_half2 = reinterpret_cast<const __half2*>(&attn_bias_fl2);
const float2 attn_bias_low = __half22float2(attn_bias_half2[0]);
const float2 attn_bias_high = __half22float2(attn_bias_half2[1]);
// residual += attention_bias
res_low.x += attn_bias_low.x;
res_low.y += attn_bias_low.y;
res_high.x += attn_bias_high.x;
res_high.y += attn_bias_high.y;
}
// residual = hidden_state + attention + (residual + bias) * mp_scale
res_low.x = attn_low.x + hs_low.x + (res_low.x + bias_low.x) * mp_scale;
res_low.y = attn_low.y + hs_low.y + (res_low.y + bias_low.y) * mp_scale;
res_high.x = attn_high.x + hs_high.x + (res_high.x + bias_high.x) * mp_scale;
res_high.y = attn_high.y + hs_high.y + (res_high.y + bias_high.y) * mp_scale;
res_half2[0] = __float22half2_rn(res_low);
res_half2[1] = __float22half2_rn(res_high);
res_fl2_ptr[offset] = res_fl2;
}
}
template <typename T>
void launch_gptj_residual_add(T* residual,
T* hidden_state,
T* attn,
T* bias,
T* attn_bias,
int hidden_dim,
int batch,
int mp_size,
hipStream_t stream)
{
int total_count = batch * hidden_dim / 4;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size);
hipLaunchKernelGGL(( gptj_residual_add), dim3(grid_dims), dim3(block_dims), 0, stream,
residual, hidden_state, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size);
}
template void launch_gptj_residual_add<float>(float*,
float*,
float*,
float*,
float*,
int,
int,
int,
hipStream_t);
template void launch_gptj_residual_add<__half>(__half*,
__half*,
__half*,
__half*,
__half*,
int,
int,
int,
hipStream_t);
template <typename T>
__global__ void moe_res_matmul(T* residual, T* coef, T* mlp_out, int seq_len, int hidden_dim)
{
constexpr int granularity = 16;
constexpr int vals_per_access = granularity / sizeof(T);
T* residual_seq = residual + blockIdx.x * hidden_dim;
T* mlp_out_seq = mlp_out + blockIdx.x * hidden_dim;
for (unsigned tid = threadIdx.x * vals_per_access; tid < hidden_dim;
tid += blockDim.x * vals_per_access) {
T mlp[vals_per_access];
T res[vals_per_access];
T coef1[vals_per_access];
T coef2[vals_per_access];
mem_access::load_global<granularity>(mlp, mlp_out_seq + tid);
mem_access::load_global<granularity>(res, residual_seq + tid);
mem_access::load_global<granularity>(coef1, coef + tid);
mem_access::load_global<granularity>(coef2, coef + tid + hidden_dim);
#pragma unroll
for (int idx = 0; idx < vals_per_access; idx++) {
mlp[idx] = mlp[idx] * coef2[idx] + res[idx] * coef1[idx];
}
mem_access::store_global<granularity>(mlp_out_seq + tid, mlp);
}
}
template <typename T>
void launch_moe_res_matmul(T* residual,
T* coef,
T* mlp_out,
int seq_len,
int hidden_dim,
hipStream_t stream)
{
dim3 grid_dim(seq_len);
dim3 block_dim(1024);
hipLaunchKernelGGL(( moe_res_matmul), dim3(grid_dim), dim3(block_dim), 0, stream,
residual, coef, mlp_out, seq_len, hidden_dim);
}
template void launch_moe_res_matmul(float* residual,
float* coef,
float* mlp_out,
int seq_len,
int hidden_dim,
hipStream_t stream);
template void launch_moe_res_matmul(__half* residual,
__half* coef,
__half* mlp_out,
int seq_len,
int hidden_dim,
hipStream_t stream);
__global__ void pad_data_kernel(__half* padded_output,
__half* output,
int head_size,
int padded_head_size)
{
float4* padded_output_cast = reinterpret_cast<float4*>(padded_output);
float4* output_cast = reinterpret_cast<float4*>(output);
int bid = blockIdx.x * (blockDim.y) + threadIdx.y;
int idx = threadIdx.x;
padded_output_cast += (bid * padded_head_size);
output_cast += (bid * head_size);
float4 ZERO;
const __half2 zero_h = __float2half2_rn(0.f);
__half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO);
#pragma unroll
for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h;
if (idx < head_size)
padded_output_cast[idx] = output_cast[idx];
else
padded_output_cast[idx] = ZERO;
}
__global__ void pad_data_kernel(float* padded_output,
float* output,
int head_size,
int padded_head_size)
{
}
template <typename T>
void pad_data(T* padded_output,
T* output,
int bsz,
int head_size,
int padded_head_size,
hipStream_t stream)
{
dim3 grid_dim((bsz - 1) / 16 + 1);
dim3 block_dim(padded_head_size / 8, 16);
hipLaunchKernelGGL(( pad_data_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
padded_output, output, head_size / 8, padded_head_size / 8);
}
template void pad_data(__half* padded_output,
__half* output,
int bsz,
int head_size,
int padded_head_size,
hipStream_t stream);
template void pad_data(float* padded_output,
float* output,
int bsz,
int head_size,
int padded_head_size,
hipStream_t stream);
__global__ void pad_head_seq_kernel(__half* padded_output,
__half* output,
int seq_len,
int padded_seq_len,
int head_size,
int padded_head_size)
{
float4* padded_output_cast = reinterpret_cast<float4*>(padded_output);
float4* output_cast = reinterpret_cast<float4*>(output);
int bsz = blockIdx.x;
int bid = blockIdx.y * (blockDim.y) + threadIdx.y;
int idx = threadIdx.x;
padded_output_cast += (bsz * padded_seq_len + bid) * padded_head_size;
output_cast += (bsz * seq_len + bid) * head_size;
float4 ZERO;
const __half2 zero_h = __float2half2_rn(0.f);
__half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO);
#pragma unroll
for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h;
if (idx < head_size && bid < seq_len)
padded_output_cast[idx] = output_cast[idx];
else
padded_output_cast[idx] = ZERO;
}
__global__ void pad_head_seq_kernel(float* padded_output,
float* output,
int seq_len,
int padded_seq_len,
int head_size,
int padded_head_size)
{
}
template <typename T>
void pad_head_seq(T* padded_output,
T* output,
int bsz,
int seq_len,
int padded_seq_len,
int head_size,
int padded_head_size,
hipStream_t stream)
{
dim3 grid_dim(bsz, padded_seq_len / 16);
dim3 block_dim(padded_head_size / 8, 16);
hipLaunchKernelGGL(( pad_head_seq_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
padded_output, output, seq_len, padded_seq_len, head_size / 8, padded_head_size / 8);
}
template void pad_head_seq(__half* padded_output,
__half* output,
int bsz,
int seq_len,
int padded_seq_len,
int head_size,
int padded_head_size,
hipStream_t stream);
template void pad_head_seq(float* padded_output,
float* output,
int bsz,
int seq_len,
int padded_seq_len,
int head_size,
int padded_head_size,
hipStream_t stream);
// TODO(cmikeh2): evaluate different GeLU performance
__device__ __forceinline__ float old_gelu(float val)
{
// 1 / sqrt(2)
constexpr float rsqrt_2 = 0.707106769084930419922;
return val * 0.5f * (1.0f + erff(val * rsqrt_2));
}
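/* Added annotation (not in the original file): unlike the tanh approximation at
   the top of this file, old_gelu evaluates the exact definition
       gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))),
   so the GEGLU path below gates with the exact GeLU rather than the approximation. */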
namespace fused_geglu {
constexpr int threads = 256;
constexpr int steps = 2;
constexpr int granularity = 16;
} // namespace fused_geglu
template <typename T>
__global__ void fused_bias_geglu(T* output,
const T* activation,
const T* bias,
int base_channels,
int total_elems)
{
constexpr int T_per_access = fused_geglu::granularity / sizeof(T);
constexpr int T_per_step = T_per_access * fused_geglu::threads;
constexpr int T_per_block = T_per_step * fused_geglu::steps;
const int id = blockIdx.x * T_per_block + threadIdx.x * T_per_access;
#pragma unroll
for (int i = 0; i < fused_geglu::steps; i++) {
T activation_buffer_1[T_per_access];
T activation_buffer_2[T_per_access];
T bias_buffer_1[T_per_access];
T bias_buffer_2[T_per_access];
const int iter_id = id + T_per_step * i;
if (iter_id < total_elems) {
const int channel_id = iter_id % base_channels;
const int seq_id = iter_id / base_channels;
const int seq_offset = seq_id * base_channels * 2;
mem_access::load_global<fused_geglu::granularity>(activation_buffer_1,
activation + seq_offset + channel_id);
mem_access::load_global<fused_geglu::granularity>(
activation_buffer_2, activation + seq_offset + channel_id + base_channels);
mem_access::load_global<fused_geglu::granularity>(bias_buffer_1, bias + channel_id);
mem_access::load_global<fused_geglu::granularity>(bias_buffer_2,
bias + channel_id + base_channels);
// Since the GeLU is going to happen at float, might as well
// convert
#pragma unroll
for (int v = 0; v < T_per_access; v++) {
T hidden_state = activation_buffer_1[v] + bias_buffer_1[v];
T pre_gate = activation_buffer_2[v] + bias_buffer_2[v];
float gate_f = old_gelu(conversion::to<float>(pre_gate));
T gate = conversion::to<T>(gate_f);
activation_buffer_1[v] = hidden_state * gate;
}
mem_access::store_global<fused_geglu::granularity>(output + iter_id,
activation_buffer_1);
}
}
}
template <typename T>
void launch_fused_bias_geglu(T* output,
const T* activation,
const T* bias,
int rows,
int elems_per_row,
hipStream_t stream)
{
/*
Fused bias GEGLU is a variant of the gated activation functions.
The input here is a matrix of [batch, seq_len, 2 * intermediate_dim]
where the second half of the channels act as GeLU gates for the first
half.
*/
// Re-derive the above figures
constexpr int T_per_access = fused_geglu::granularity / sizeof(T);
constexpr int T_per_step = T_per_access * fused_geglu::threads;
constexpr int T_per_block = T_per_step * fused_geglu::steps;
const int base_channels = elems_per_row / 2;
const int total_elems = base_channels * rows;
dim3 block(fused_geglu::threads);
dim3 grid((total_elems + T_per_block - 1) / T_per_block);
hipLaunchKernelGGL(( fused_bias_geglu), dim3(grid), dim3(block), 0, stream,
output, activation, bias, base_channels, total_elems);
}
template void launch_fused_bias_geglu(__half*,
const __half*,
const __half*,
int,
int,
hipStream_t);
template void launch_fused_bias_geglu(float*, const float*, const float*, int, int, hipStream_t);
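A minimal host-side sketch of how launch_fused_bias_geglu is intended to be driven, following the layout described in the comment inside the launcher (the shapes and the hipMalloc-based setup here are illustrative assumptions, not part of the original binding code):
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>

// Assumed to be declared in the corresponding header (inference_cuda_layers.h).
template <typename T>
void launch_fused_bias_geglu(T* output, const T* activation, const T* bias,
                             int rows, int elems_per_row, hipStream_t stream);

void geglu_example(hipStream_t stream)
{
    // Hypothetical shape: [batch = 4, seq_len = 256, 2 * intermediate_dim = 8192].
    const int rows = 4 * 256;
    const int elems_per_row = 8192;      // value half followed by gate half
    const int out_elems_per_row = 4096;  // output keeps only the gated value half

    __half *activation, *bias, *output;
    hipMalloc(&activation, sizeof(__half) * rows * elems_per_row);
    hipMalloc(&bias, sizeof(__half) * elems_per_row);
    hipMalloc(&output, sizeof(__half) * rows * out_elems_per_row);

    // The launcher derives base_channels = elems_per_row / 2 internally and computes
    // output = (activation_1 + bias_1) * GeLU(activation_2 + bias_2) per channel.
    launch_fused_bias_geglu(output, activation, bias, rows, elems_per_row, stream);

    hipFree(activation);
    hipFree(bias);
    hipFree(output);
}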
csrc/transformer/inference/csrc/layer_norm.hip (deleted, 100644 → 0)
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include "conversion_utils.h"
#include "ds_kernel_utils.h"
#include "inference_cuda_layers.h"
#include "memory_access_utils.h"
#include "reduction_utils.h"
namespace cg = cooperative_groups;
using rop = reduce::ROpType;
namespace ln {
constexpr int granularity = 16;
} // namespace ln
/*
Primary layer norm implementation. Assumes elems_per_row % 8
is equal to 0.
Args:
output: buffer for output data
vals: buffer for input data
gamma: gain for normalization
beta: bias for normalization
epsilon: numeric stability
elems_per_row: number of elements each block will normalize
*/
template <typename T, int unRoll, int threadsPerGroup, int maxThreads>
__global__ void fused_ln(T* output,
const T* vals,
const T* gamma,
const T* beta,
float epsilon,
int elems_per_row)
{
constexpr int T_per_load = ln::granularity / sizeof(T);
cg::thread_block tb = cg::this_thread_block();
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
// X-dimension of the block
const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) +
(tb.thread_index().y * elems_per_row);
const int thread_offset = tb.thread_index().x * T_per_load;
const int base_offset = block_offset + thread_offset;
const int stride = tb.size() * T_per_load;
float sum = reduce::init<rop::Add, float>();
const T* input_base = vals + base_offset;
T local_buffer[unRoll * T_per_load];
#pragma unroll
for (int i = 0; i < unRoll; i++) {
T* iteration_buffer = local_buffer + i * T_per_load;
T residual_buffer[T_per_load];
T bias_buffer[T_per_load];
mem_access::load_global<ln::granularity>(
iteration_buffer, input_base + i * stride, thread_offset + i * stride < elems_per_row);
#pragma unroll
for (int j = 0; j < T_per_load; j++) {
float vals_up_cast = conversion::to<float>(iteration_buffer[j]);
sum = reduce::element<rop::Add>(sum, vals_up_cast);
}
}
reduce::partitioned_block<rop::Add, threadsPerGroup>(tb, warp, sum);
const float mean = sum / elems_per_row;
float mean_diff = reduce::init<rop::Add, float>();
#pragma unroll
for (int i = 0; i < unRoll; i++) {
#pragma unroll
for (int j = 0; j < T_per_load; j++) {
// Using a 0 value here skews the variance, have to if-guard
if (thread_offset + i * stride < elems_per_row) {
float diff = (conversion::to<float>(local_buffer[i * T_per_load + j]) - mean);
mean_diff = reduce::element<rop::Add>(mean_diff, diff * diff);
}
}
}
reduce::partitioned_block<rop::Add, threadsPerGroup>(tb, warp, mean_diff);
const float variance = mean_diff / elems_per_row;
const float denom = __frsqrt_rn(variance + epsilon);
const T mean_compute = conversion::to<T>(mean);
const T denom_compute = conversion::to<T>(denom);
T* block_output = output + block_offset;
#pragma unroll
for (int i = 0; i < unRoll; i++) {
T* iteration_buffer = local_buffer + i * T_per_load;
const int iter_idx = i * stride + thread_offset;
const bool do_loads = iter_idx < elems_per_row;
T gamma_local[T_per_load], beta_local[T_per_load];
mem_access::load_global<ln::granularity>(gamma_local, gamma + iter_idx, do_loads);
mem_access::load_global<ln::granularity>(beta_local, beta + iter_idx, do_loads);
#pragma unroll
for (int j = 0; j < T_per_load; j++) {
iteration_buffer[j] = (iteration_buffer[j] - mean_compute) * denom_compute;
iteration_buffer[j] = iteration_buffer[j] * gamma_local[j] + beta_local[j];
}
if (do_loads) {
mem_access::store_global<ln::granularity>(block_output + iter_idx, iteration_buffer);
}
}
}
#define LAUNCH_FUSED_LN(unRollFactor, threadsPerGroup, maxThreads) \
hipLaunchKernelGGL(( fused_ln<T, unRollFactor, threadsPerGroup, maxThreads>) \
, dim3(grid), dim3(block), 0, stream, output, vals, gamma, beta, epsilon, elems_per_row);
template <typename T>
void launch_fused_ln(T* output,
const T* vals,
const T* gamma,
const T* beta,
float epsilon,
int rows,
int elems_per_row,
hipStream_t stream)
{
// 8 for __half, 4 for float
constexpr int T_per_load = ln::granularity / sizeof(T);
constexpr int maxThreads = 256;
// For float, unroll 4; for __half, unroll 2
constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2;
const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false;
const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll;
// Scheduling concern: may be slightly faster for some inputs to assign multiple stages of
// warp-sized blocks rather than stepping up to 64/96 threads
const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step);
const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads;
const int groups_per_block_max =
is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1;
const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max;
const int groups_launch = (groups_per_block + rows - 1) / groups_per_block;
dim3 block(threadsPerGroup, groups_per_block);
dim3 grid(groups_launch);
const int elems_per_step = threadsPerGroup * h_per_step;
const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step;
if (is_subblock_schedule) {
// <=128
if (threadsPerGroup == 1) {
LAUNCH_FUSED_LN(1, 1, maxThreads);
} else if (threadsPerGroup == 2) {
LAUNCH_FUSED_LN(1, 2, maxThreads);
} else if (threadsPerGroup == 4) {
LAUNCH_FUSED_LN(1, 4, maxThreads);
} else if (threadsPerGroup == 8) {
LAUNCH_FUSED_LN(1, 8, maxThreads);
} else if (threadsPerGroup == 16) {
LAUNCH_FUSED_LN(1, 16, maxThreads);
}
} else if (external_unRoll == 1) {
// 129 - 4096 elems
// (this can launch with 1-7 warps as well)
LAUNCH_FUSED_LN(1 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 2) {
// 4097 - 8192 elems
LAUNCH_FUSED_LN(2 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 3) {
// 8193 - 12288 elems
LAUNCH_FUSED_LN(3 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 4) {
// 12289 - 16384 elems
LAUNCH_FUSED_LN(4 * internal_unRoll, maxThreads, maxThreads);
}
}
template void launch_fused_ln(__half*,
const __half*,
const __half*,
const __half*,
float,
int,
int,
hipStream_t);
template void
launch_fused_ln(float*, const float*, const float*, const float*, float, int, int, hipStream_t);
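The scheduling logic in launch_fused_ln can be hard to trace; the sketch below re-derives its launch parameters on the host for a few representative __half row widths (the helper and the chosen sizes are illustrative assumptions, written to mirror the arithmetic above rather than to be authoritative):
#include <cstdio>

// Mirrors the behavior assumed for the file's next_pow2 helper.
static int next_pow2_example(int v)
{
    int p = 1;
    while (p < v) p <<= 1;
    return p;
}

void fused_ln_schedule(int elems_per_row)
{
    // Matches launch_fused_ln<__half>: 16-byte loads, 256-thread cap, internal unroll of 2.
    const int T_per_load = 16 / 2, maxThreads = 256, internal_unRoll = 2;
    const bool subblock = elems_per_row <= 128;
    const int h_per_step = subblock ? T_per_load : T_per_load * internal_unRoll;
    const int one_step = next_pow2_example((elems_per_row + h_per_step - 1) / h_per_step);
    const int threadsPerGroup = one_step < maxThreads ? one_step : maxThreads;
    const int elems_per_step = threadsPerGroup * h_per_step;
    const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step;
    printf("elems=%d -> threadsPerGroup=%d, external_unRoll=%d\n",
           elems_per_row, threadsPerGroup, external_unRoll);
}

int main()
{
    fused_ln_schedule(64);     // sub-block path: 8 threads per row, several rows per block
    fused_ln_schedule(4096);   // full-block path: 256 threads, external_unRoll = 1
    fused_ln_schedule(12288);  // full-block path: 256 threads, external_unRoll = 3
    return 0;
}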
/*
Fused residual + bias + layer norm implementation. Assumes elems_per_row % 8
is equal to 0.
TODO(cmikeh2): Goal is to deprecate this implementation. The bias + residual
need to be fused into compute-bound producer operations.
Args:
output: buffer for output data
res_output: output of residual addition
vals: buffer for input data
residual: residual data
bias: bias of input data
gamma: gain for normalization
beta: bias for normalization
epsilon: numeric stability
elems_per_row: number of elements each block will normalize
Template arg:
preLnResidual: controls whether the residual calculation is stored
or not. When set to false, the input `res_output` is unused.
*/
template <typename T, int unRoll, int threadsPerGroup, int maxThreads, bool preLnResidual>
__global__ void fused_residual_ln(T* output,
T* res_output,
const T* vals,
const T* residual,
const T* bias,
const T* gamma,
const T* beta,
float epsilon,
int elems_per_row)
{
constexpr int T_per_load = ln::granularity / sizeof(T);
cg::thread_block tb = cg::this_thread_block();
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
// X-dimension of the block
const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) +
(tb.thread_index().y * elems_per_row);
const int thread_offset = tb.thread_index().x * T_per_load;
const int base_offset = block_offset + thread_offset;
const int stride = tb.size() * T_per_load;
float sum = reduce::init<rop::Add, float>();
const T* input_base = vals + base_offset;
const T* residual_base = residual + base_offset;
const T* bias_base = bias + thread_offset;
T local_buffer[unRoll * T_per_load];
// Unlike a vanilla layernorm, since we're fusing the two adds as well
// an inner unRoll seems to be less valuable. If anything, a double unRoll
// makes the most sense if we find we are having performance issues.
#pragma unroll
for (int i = 0; i < unRoll; i++) {
T* iteration_buffer = local_buffer + i * T_per_load;
T residual_buffer[T_per_load];
T bias_buffer[T_per_load];
mem_access::load_global<ln::granularity>(
iteration_buffer, input_base + i * stride, thread_offset + i * stride < elems_per_row);
mem_access::load_global<ln::granularity>(residual_buffer,
residual_base + i * stride,
thread_offset + i * stride < elems_per_row);
mem_access::load_global<ln::granularity>(
bias_buffer, bias_base + i * stride, thread_offset + i * stride < elems_per_row);
#pragma unroll
for (int j = 0; j < T_per_load; j++) {
float vals_up_cast = conversion::to<float>(iteration_buffer[j]);
float res_up_cast = conversion::to<float>(residual_buffer[j]);
float bias_up_cast = conversion::to<float>(bias_buffer[j]);
vals_up_cast += res_up_cast + bias_up_cast;
sum = reduce::element<rop::Add>(sum, vals_up_cast);
iteration_buffer[j] = conversion::to<T>(vals_up_cast);
}
if (preLnResidual && (thread_offset + i * stride < elems_per_row)) {
mem_access::store_global<ln::granularity>(res_output + base_offset + i * stride,
iteration_buffer);
}
}
reduce::partitioned_block<rop::Add, threadsPerGroup>(tb, warp, sum);
const float mean = sum / elems_per_row;
float mean_diff = reduce::init<rop::Add, float>();
#pragma unroll
for (int i = 0; i < unRoll; i++) {
#pragma unroll
for (int j = 0; j < T_per_load; j++) {
// Using a 0 value here skews the variance, have to if-guard
if (thread_offset + i * stride < elems_per_row) {
float diff = (conversion::to<float>(local_buffer[i * T_per_load + j]) - mean);
mean_diff = reduce::element<rop::Add>(mean_diff, diff * diff);
}
}
}
reduce::partitioned_block<rop::Add, threadsPerGroup>(tb, warp, mean_diff);
const float variance = mean_diff / elems_per_row;
const float denom = __frsqrt_rn(variance + epsilon);
const T mean_compute = conversion::to<T>(mean);
const T denom_compute = conversion::to<T>(denom);
T* block_output = output + block_offset;
#pragma unroll
for (int i = 0; i < unRoll; i++) {
T* iteration_buffer = local_buffer + i * T_per_load;
const int iter_idx = i * stride + thread_offset;
const bool do_loads = iter_idx < elems_per_row;
T gamma_local[T_per_load], beta_local[T_per_load];
mem_access::load_global<ln::granularity>(gamma_local, gamma + iter_idx, do_loads);
mem_access::load_global<ln::granularity>(beta_local, beta + iter_idx, do_loads);
#pragma unroll
for (int j = 0; j < T_per_load; j++) {
iteration_buffer[j] = (iteration_buffer[j] - mean_compute) * denom_compute;
iteration_buffer[j] = iteration_buffer[j] * gamma_local[j] + beta_local[j];
}
if (do_loads) {
mem_access::store_global<ln::granularity>(block_output + iter_idx, iteration_buffer);
}
}
}
// TODO(cmikeh2): There's a bunch of redundancy here that needs to be removed/simplified.
#define LAUNCH_FUSED_RES_LN(unRollFactor, threadsPerGroup, maxThreads) \
hipLaunchKernelGGL(( fused_residual_ln<T, unRollFactor, threadsPerGroup, maxThreads, false>) \
, dim3(grid), dim3(block), 0, stream, \
output, nullptr, vals, residual, bias, gamma, beta, epsilon, elems_per_row);
template <typename T>
void launch_fused_residual_ln(T* output,
const T* vals,
const T* residual,
const T* bias,
const T* gamma,
const T* beta,
float epsilon,
int rows,
int elems_per_row,
hipStream_t stream)
{
// 8 for __half, 4 for float
constexpr int T_per_load = ln::granularity / sizeof(T);
constexpr int maxThreads = 256;
// For float, unroll 4; for __half, unroll 2
constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2;
const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false;
const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll;
// Scheduling concern: may be slightly faster for some inputs to assign multiple stages of
// warp-sized blocks rather than stepping up to 64/96 threads
const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step);
const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads;
const int groups_per_block_max =
is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1;
const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max;
const int groups_launch = (groups_per_block + rows - 1) / groups_per_block;
dim3 block(threadsPerGroup, groups_per_block);
dim3 grid(groups_launch);
const int elems_per_step = threadsPerGroup * h_per_step;
const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step;
if (is_subblock_schedule) {
// <=128
if (threadsPerGroup == 1) {
LAUNCH_FUSED_RES_LN(1, 1, maxThreads);
} else if (threadsPerGroup == 2) {
LAUNCH_FUSED_RES_LN(1, 2, maxThreads);
} else if (threadsPerGroup == 4) {
LAUNCH_FUSED_RES_LN(1, 4, maxThreads);
} else if (threadsPerGroup == 8) {
LAUNCH_FUSED_RES_LN(1, 8, maxThreads);
} else if (threadsPerGroup == 16) {
LAUNCH_FUSED_RES_LN(1, 16, maxThreads);
}
} else if (external_unRoll == 1) {
// 129 - 4096 elems
// (this can launch with 1-7 warps as well)
LAUNCH_FUSED_RES_LN(1 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 2) {
// 4097 - 8192 elems
LAUNCH_FUSED_RES_LN(2 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 3) {
// 8193 - 12288 elems
LAUNCH_FUSED_RES_LN(3 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 4) {
// 12289 - 16384 elems
LAUNCH_FUSED_RES_LN(4 * internal_unRoll, maxThreads, maxThreads);
}
}
#define LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(unRollFactor, threadsPerGroup, maxThreads) \
hipLaunchKernelGGL(( fused_residual_ln<T, unRollFactor, threadsPerGroup, maxThreads, true>) \
, dim3(grid), dim3(block), 0, stream, \
norm_output, res_output, vals, residual, bias, gamma, beta, epsilon, elems_per_row);
template <typename T>
void launch_fused_residual_ln_store_pre_ln_res(T* norm_output,
T* res_output,
const T* vals,
const T* residual,
const T* bias,
const T* gamma,
const T* beta,
float epsilon,
int rows,
int elems_per_row,
hipStream_t stream)
{
// 8 for __half, 4 for float
constexpr int T_per_load = ln::granularity / sizeof(T);
constexpr int maxThreads = 256;
// For float, unroll 4; for __half, unroll 2
constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2;
const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false;
const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll;
// Scheduling concern: may be slightly faster for some inputs to assign multiple stages of
// warp-sized blocks rather than stepping up to 64/96 threads
const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step);
const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads;
const int groups_per_block_max =
is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1;
const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max;
const int groups_launch = (groups_per_block + rows - 1) / groups_per_block;
dim3 block(threadsPerGroup, groups_per_block);
dim3 grid(groups_launch);
const int elems_per_step = threadsPerGroup * h_per_step;
const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step;
if (is_subblock_schedule) {
// <=128
if (threadsPerGroup == 1) {
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 1, maxThreads);
} else if (threadsPerGroup == 2) {
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 2, maxThreads);
} else if (threadsPerGroup == 4) {
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 4, maxThreads);
} else if (threadsPerGroup == 8) {
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 8, maxThreads);
} else if (threadsPerGroup == 16) {
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 16, maxThreads);
}
} else if (external_unRoll == 1) {
// 129 - 4096 elems
// (this can launch with 1-7 warps as well)
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 2) {
// 4097 - 8192 elems
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(2 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 3) {
// 8193 - 12288 elems
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(3 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 4) {
// 12289 - 16384 elems
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(4 * internal_unRoll, maxThreads, maxThreads);
}
}
// No-store specializations
template void launch_fused_residual_ln(__half*,
const __half*,
const __half*,
const __half*,
const __half*,
const __half*,
float,
int,
int,
hipStream_t);
template void launch_fused_residual_ln(float*,
const float*,
const float*,
const float*,
const float*,
const float*,
float,
int,
int,
hipStream_t);
// Store specializations
template void launch_fused_residual_ln_store_pre_ln_res(__half*,
__half*,
const __half*,
const __half*,
const __half*,
const __half*,
const __half*,
float,
int,
int,
hipStream_t);
template void launch_fused_residual_ln_store_pre_ln_res(float*,
float*,
const float*,
const float*,
const float*,
const float*,
const float*,
float,
int,
int,
hipStream_t);
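As a readability aid, the scalar math that the fused residual layer norm performs on each row can be written out as follows (a sketch only; the buffer names are illustrative and the real kernels operate on vectorized, block-parallel data):
#include <cmath>
#include <vector>

// Reference for one row: res_out[i] = x[i] + residual[i] + bias[i];
// norm_out[i] = gamma[i] * (res_out[i] - mean) / sqrt(var + eps) + beta[i].
void residual_ln_reference(const std::vector<float>& x,
                           const std::vector<float>& residual,
                           const std::vector<float>& bias,
                           const std::vector<float>& gamma,
                           const std::vector<float>& beta,
                           float eps,
                           std::vector<float>& res_out,
                           std::vector<float>& norm_out)
{
    const int n = static_cast<int>(x.size());
    float mean = 0.f, var = 0.f;
    for (int i = 0; i < n; i++) {
        res_out[i] = x[i] + residual[i] + bias[i];
        mean += res_out[i];
    }
    mean /= n;
    for (int i = 0; i < n; i++) {
        const float d = res_out[i] - mean;
        var += d * d;
    }
    var /= n;
    const float inv_std = 1.f / std::sqrt(var + eps);
    for (int i = 0; i < n; i++)
        norm_out[i] = (res_out[i] - mean) * inv_std * gamma[i] + beta[i];
}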
csrc/transformer/inference/csrc/normalize.hip (deleted, 100644 → 0)
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include <limits>
#include "custom_hip_layers.h"
#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif
#include <cstdio>
#include <cstdlib>
#include <ctime>
#define NORM_REG (MAX_REGISTERS)
namespace cg = cooperative_groups;
__global__ void fused_bias_residual_layer_norm(float* output,
const float* vals,
const float* gamma,
const float* beta,
float epsilon,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
float inp_reg[NORM_REG];
int k = 0;
float sum = 0;
int input_id = id;
while (input_id < row_stride) {
inp_reg[k] = vals[input_id + row * row_stride];
sum += inp_reg[k++];
input_id += iteration_stride;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride);
sum = 0.f;
for (int f = 0; f < k; f++) {
inp_reg[f] -= mean;
sum += inp_reg[f] * inp_reg[f];
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride);
sum += epsilon;
sum = __frsqrt_rn(sum);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * sum;
inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
output[out_id + row * row_stride] = inp_reg[f];
}
}
__global__ void fused_bias_residual_layer_norm(__half* output,
const __half* vals,
const __half* gamma,
const __half* beta,
float epsilon,
int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
__half2 inp_reg[NORM_REG];
const __half2* vals_cast = reinterpret_cast<const __half2*>(vals);
__half2* out_cast = reinterpret_cast<__half2*>(output);
int k = 0;
int input_id = id;
while (input_id < row_stride) {
inp_reg[k++] = vals_cast[input_id + row * row_stride];
input_id += iteration_stride;
}
float sum = 0;
for (int f = k - 1; f >= 0; f--) {
float2 inp_f = __half22float2(inp_reg[f]);
sum += inp_f.x + inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride << 1);
sum = 0.f;
for (int f = 0; f < k; f++) {
float2 inp_f = __half22float2(inp_reg[f]);
inp_f.x -= mean;
inp_f.y -= mean;
inp_reg[f] = __float22half2_rn(inp_f);
sum += inp_f.x * inp_f.x;
sum += inp_f.y * inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride << 1);
sum += epsilon;
sum = __frsqrt_rn(sum);
__half2 variance_h = __float2half2_rn(sum);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * variance_h;
inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id];
out_cast[out_id + row * row_stride] = inp_reg[f];
}
#endif
}
template <typename T>
void launch_layer_norm(T* out,
T* vals,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream);
template <>
void launch_layer_norm<float>(float* out,
float* vals,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
out, vals, gamma, beta, epsilon, hidden_dim);
}
template <>
void launch_layer_norm<__half>(__half* out,
__half* vals,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
out, vals, gamma, beta, epsilon, hidden_dim / 2);
}
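Note that the __half specialization passes hidden_dim / 2 because fused_bias_residual_layer_norm walks each row as __half2 pairs (and divides the mean/variance sums by row_stride << 1 accordingly). A small host-side sanity check of that packing assumption, using a hypothetical hidden size and register budget:
#include <cassert>

// Illustrative only: the __half launcher needs an even hidden dimension and enough
// NORM_REG registers to hold one row at 1024 threads per block.
void check_half2_layer_norm_config(int hidden_dim, int norm_reg)
{
    assert(hidden_dim % 2 == 0);            // row is reinterpreted as __half2
    const int row_stride = hidden_dim / 2;  // what the kernel receives
    const int threads = 1024;               // block size used by the launcher
    const int regs_needed = (row_stride + threads - 1) / threads;
    assert(regs_needed <= norm_reg);        // inp_reg[NORM_REG] must fit the row
}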
__global__ void fused_residual_layer_norm(float* norm,
float* res_add,
float* vals,
float* residual,
const float* bias,
const float* gamma,
const float* beta,
float epsilon,
int row_stride,
bool preLN,
bool mlp_after_attn)
{
int iteration_stride = blockDim.x;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
float inp_reg[NORM_REG];
int k = 0;
int input_id = id;
float sum = 0;
while (input_id < row_stride) {
inp_reg[k] = vals[input_id + row * row_stride];
float res_f = (residual[input_id + row * row_stride]);
float bias_f = (bias[input_id]);
if (mlp_after_attn) inp_reg[k] += res_f + bias_f;
// if (preLN) res_add[input_id + row * row_stride] = inp_reg[k];
sum += inp_reg[k++];
input_id += iteration_stride;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride);
sum = 0.f;
for (int f = 0; f < k; f++) {
inp_reg[f] -= mean;
sum += inp_reg[f] * inp_reg[f];
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride);
sum += epsilon;
sum = __frsqrt_rn(sum);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * sum;
inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
norm[out_id + row * row_stride] = inp_reg[f];
}
}
__global__ void fused_residual_layer_norm(__half* norm,
__half* res_add,
__half* vals,
__half* residual,
const __half* bias,
const __half* gamma,
const __half* beta,
float epsilon,
int row_stride,
bool preLN,
bool mlp_after_attn)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
__half2 inp_reg[NORM_REG];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
__half2* norm_cast = reinterpret_cast<__half2*>(norm);
__half2* res_add_cast = reinterpret_cast<__half2*>(res_add);
__half2* residual_cast = reinterpret_cast<__half2*>(residual);
const __half2* bias_cast = reinterpret_cast<const __half2*>(bias);
int k = 0;
int input_id = id;
float sum = 0;
while (input_id < row_stride) {
inp_reg[k] = vals_cast[input_id + row * row_stride];
float2 inp_f = __half22float2(inp_reg[k]);
float2 res_f = __half22float2(residual_cast[input_id + row * row_stride]);
float2 bias_f = __half22float2(bias_cast[input_id]);
if (mlp_after_attn) {
inp_f.x += res_f.x + bias_f.x;
inp_f.y += res_f.y + bias_f.y;
}
inp_reg[k] = __float22half2_rn(inp_f);
// if (preLN) res_add_cast[input_id + row * row_stride] = __float22half2_rn(res_f);
// //inp_reg[k];
sum += inp_f.x + inp_f.y;
input_id += iteration_stride;
k++;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride << 1);
sum = 0.f;
for (int f = 0; f < k; f++) {
float2 inp_f = __half22float2(inp_reg[f]);
inp_f.x -= mean;
inp_f.y -= mean;
inp_reg[f] = __float22half2_rn(inp_f);
sum += inp_f.x * inp_f.x;
sum += inp_f.y * inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride << 1);
sum += epsilon;
sum = __frsqrt_rn(sum);
__half2 variance_h = __float2half2_rn(sum);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * variance_h;
inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id];
norm_cast[out_id + row * row_stride] = inp_reg[f];
}
#endif
}
template <typename T>
void launch_residual_layer_norm(T* norm,
T* res_add,
T* vals,
T* residual,
const T* bias,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
hipStream_t stream);
template <>
void launch_residual_layer_norm<float>(float* norm,
float* res_add,
float* vals,
float* residual,
const float* bias,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
hipStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream, norm,
res_add,
vals,
residual,
bias,
gamma,
beta,
epsilon,
hidden_dim,
preLN,
mlp_after_attn);
}
template <>
void launch_residual_layer_norm<__half>(__half* norm,
__half* res_add,
__half* vals,
__half* residual,
const __half* bias,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
hipStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream, norm,
res_add,
vals,
residual,
bias,
gamma,
beta,
epsilon,
hidden_dim / 2,
preLN,
mlp_after_attn);
}
csrc/transformer/inference/csrc/pt_binding_hip.cpp (deleted, 100644 → 0)
// !!! This is a file automatically generated by hipify!!!
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
#include <torch/extension.h>
#include <stdexcept>
#include <vector>
#include "inference_context.h"
#include "inference_cublas_wrappers.h"
#include "inference_cuda_layers.h"
std
::
array
<
int
,
3
>
gemm_algos
=
std
::
array
<
int
,
3
>
({
99
,
99
,
99
});
// NOTE: This activation function type enum should be always in sync
// with the python counterpart, otherwise the casting from python binding
// will be incorrect.
enum
class
ActivationFuncType
{
UNKNOWN
=
0
,
GELU
=
1
,
ReLU
=
2
};
enum
class
TransformerType
:
uint8_t
{
UNKNOWN
=
0
,
GPTType
=
1
,
BERTType
=
2
};
// NOTE: this is a temporary and dodgy solution to distinguish GPT and BERT style models
// based on the dimensions of the corresponding attention mask.
inline
auto
infer_transformer_type
(
at
::
Tensor
&
attn_mask
)
->
TransformerType
{
auto
attn_mask_num_dims
=
attn_mask
.
sizes
().
size
();
if
(
attn_mask_num_dims
>
2
)
{
return
TransformerType
::
GPTType
;
}
else
if
(
attn_mask_num_dims
==
2
)
{
return
TransformerType
::
BERTType
;
}
else
{
return
TransformerType
::
UNKNOWN
;
}
}
// infer stride of attention mask memory layout based on the model type.
inline
auto
get_attn_mask_stride
(
at
::
Tensor
&
attn_mask
)
->
int
{
auto
trnsfrmr_type
=
infer_transformer_type
(
attn_mask
);
if
(
trnsfrmr_type
==
TransformerType
::
GPTType
)
{
return
attn_mask
.
size
(
2
);
}
else
if
(
trnsfrmr_type
==
TransformerType
::
BERTType
)
{
// Bert style models have always a mask stride of 1.
return
1
;
}
else
if
(
trnsfrmr_type
==
TransformerType
::
UNKNOWN
)
{
return
0
;
}
// this is just to make the compiler happy.
return
0
;
}
template
<
typename
T
>
at
::
Tensor
ds_softmax
(
at
::
Tensor
&
attn_scores
,
at
::
Tensor
&
attn_mask
,
at
::
Tensor
&
alibi
,
bool
triangular
,
bool
recompute
,
bool
local_attention
,
int
window_size
,
bool
async_op
,
float
layer_scale
,
int
head_offset
,
int
mp_size
)
{
auto
attn_scores_c
=
attn_scores
.
contiguous
();
int
bsz
=
attn_scores_c
.
size
(
0
);
int
seq_len
=
attn_scores_c
.
size
(
1
);
int
len
=
attn_scores_c
.
sizes
().
size
();
if
(
len
>
2
)
seq_len
=
attn_scores_c
.
size
(
2
);
int
soft_len
=
attn_scores_c
.
size
(
2
);
if
(
len
>
3
)
soft_len
=
attn_scores_c
.
size
(
3
);
int
heads
=
1
;
if
(
len
>
1
)
heads
=
attn_scores_c
.
size
(
1
);
auto
mask_stride
=
get_attn_mask_stride
(
attn_mask
);
launch_attn_softmax_v2
((
T
*
)
attn_scores_c
.
data_ptr
(),
(
attn_mask
.
sizes
().
size
()
>
1
?
(
T
*
)
attn_mask
.
data_ptr
()
:
nullptr
),
(
alibi
.
sizes
().
size
()
>
1
?
(
T
*
)
alibi
.
data_ptr
()
:
nullptr
),
layer_scale
,
triangular
,
recompute
,
local_attention
,
window_size
,
bsz
,
heads
,
seq_len
,
soft_len
,
head_offset
,
mask_stride
,
mp_size
,
Context
::
Instance
().
GetCurrentStream
(
async_op
));
return
attn_scores_c
;
}
template
<
typename
T
>
void
allocate_workspace
(
unsigned
hidden_dim
,
unsigned
num_heads
,
unsigned
prompt_length
,
unsigned
batch_size
,
unsigned
num_layers
,
unsigned
mp_size
=
1
,
bool
external_cache
=
false
,
unsigned
rank
=
0
,
unsigned
max_out_tokens
=
1024
)
{
Context
::
Instance
().
GenWorkSpace
(
num_layers
,
num_heads
,
batch_size
,
prompt_length
,
hidden_dim
,
mp_size
,
external_cache
,
sizeof
(
T
),
rank
,
max_out_tokens
);
}
template
<
typename
T
>
at
::
Tensor
einsum_sec_sm_ecm
(
at
::
Tensor
&
Q
,
at
::
Tensor
&
W
)
{
auto
options
=
at
::
TensorOptions
()
.
dtype
(
Q
.
options
().
dtype
())
.
layout
(
at
::
kStrided
)
.
device
(
at
::
kCUDA
)
.
requires_grad
(
false
);
T
*
workspace
=
(
T
*
)
Context
::
Instance
().
GetWorkSpace
();
float
alpha
=
1
;
float
gemm_beta
=
0.0
;
/*
// Reallocate memory if we received a new prompt
if (!workspace || input.size(1) != 1) {
allocate_workspace<T>(W.size(1), Context::Instance().GetMaxTokenLenght(), Q.size(0), 1,
head_size); workspace = (T*)Context::Instance().GetWorkSpace();
}
*/
auto
O
=
at
::
from_blob
(
workspace
,
{
Q
.
size
(
1
),
Q
.
size
(
2
),
W
.
size
(
1
)},
options
);
unsigned
m
=
W
.
size
(
1
);
unsigned
n
=
Q
.
size
(
1
)
*
Q
.
size
(
2
);
unsigned
k
=
Q
.
size
(
0
);
cublas_gemm_ex
(
Context
::
Instance
().
GetCublasHandle
(),
rocblas_operation_none
,
rocblas_operation_transpose
,
m
,
n
,
k
,
&
alpha
,
&
gemm_beta
,
(
T
*
)
W
.
data_ptr
(),
(
T
*
)
Q
.
data_ptr
(),
(
T
*
)
O
.
data_ptr
(),
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard
);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP
);
#endif
return
O
;
}
template
<
typename
T
>
void
attention_unfused
(
at
::
Tensor
&
prev_key_cont
,
at
::
Tensor
&
query_cont
,
at
::
Tensor
&
attn_mask
,
at
::
Tensor
&
prev_value_cont
,
at
::
Tensor
&
output
,
int
&
bsz
,
int
&
seq_len
,
int
&
soft_len
,
int
&
heads
,
float
&
norm_factor
,
bool
triangular
,
bool
recompute
,
bool
local_attention
,
int
window_size
)
{
auto
options
=
at
::
TensorOptions
()
.
dtype
(
query_cont
.
options
().
dtype
())
.
layout
(
at
::
kStrided
)
.
device
(
at
::
kCUDA
)
.
requires_grad
(
false
);
float
alpha
=
norm_factor
;
float
gemm_beta
=
0.0
;
auto
attn_score
=
at
::
empty
({
bsz
,
heads
,
seq_len
,
soft_len
},
options
);
int
k
=
prev_value_cont
.
size
(
2
)
/
heads
;
auto
mask_stride
=
get_attn_mask_stride
(
attn_mask
);
rocblas_set_stream
(
Context
::
Instance
().
GetCublasHandle
(),
Context
::
Instance
().
GetCurrentStream
());
cublas_strided_batched_gemm
(
Context
::
Instance
().
GetCublasHandle
(),
soft_len
,
seq_len
,
k
,
&
alpha
,
&
gemm_beta
,
(
T
*
)
prev_key_cont
.
data_ptr
(),
(
T
*
)
query_cont
.
data_ptr
(),
(
T
*
)
attn_score
.
data_ptr
(),
rocblas_operation_none
,
rocblas_operation_none
,
soft_len
*
k
,
seq_len
*
k
,
seq_len
*
soft_len
,
bsz
*
heads
,
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard
);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP
);
#endif
launch_attn_softmax_v2
((
T
*
)
attn_score
.
data_ptr
(),
(
T
*
)(
attn_mask
.
sizes
().
size
()
>
1
?
attn_mask
.
data_ptr
()
:
nullptr
),
(
T
*
)
nullptr
,
1.0
,
triangular
,
recompute
,
local_attention
,
window_size
,
bsz
,
heads
,
seq_len
,
soft_len
,
0
,
mask_stride
,
1
,
Context
::
Instance
().
GetCurrentStream
(
false
));
alpha
=
1.0
;
cublas_strided_batched_gemm
(
Context
::
Instance
().
GetCublasHandle
(),
k
,
seq_len
,
soft_len
,
&
alpha
,
&
gemm_beta
,
(
T
*
)
prev_value_cont
.
data_ptr
(),
(
T
*
)
attn_score
.
data_ptr
(),
(
T
*
)
output
.
data_ptr
(),
rocblas_operation_none
,
rocblas_operation_none
,
soft_len
*
k
,
seq_len
*
soft_len
,
seq_len
*
k
,
bsz
*
heads
,
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard
);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP
);
#endif
}
template
<
typename
T
>
std
::
vector
<
at
::
Tensor
>
ds_softmax_context1
(
at
::
Tensor
&
query
,
at
::
Tensor
&
prev_key
,
at
::
Tensor
&
new_key
,
at
::
Tensor
&
attn_mask
,
at
::
Tensor
&
prev_value
,
at
::
Tensor
&
new_value
,
int
heads
,
float
norm_factor
,
bool
merging
,
bool
triangular
,
bool
local_attention
,
int
window_size
,
bool
no_masking
)
{
auto
query_cont
=
query
.
contiguous
();
auto
prev_key_cont
=
prev_key
.
contiguous
();
auto
prev_value_cont
=
prev_value
.
contiguous
();
int
new_size
=
(
new_value
.
sizes
().
size
()
>
1
?
new_value
.
size
(
1
)
:
0
);
// Attn_Score [ batch Head Sequence-length Softmax-length]
int
bsz
=
query_cont
.
size
(
0
);
int
seq_len
=
query_cont
.
size
(
1
);
int
soft_len
=
prev_value
.
size
(
1
);
auto
options
=
at
::
TensorOptions
()
.
dtype
(
query_cont
.
options
().
dtype
())
.
layout
(
at
::
kStrided
)
.
device
(
at
::
kCUDA
)
.
requires_grad
(
false
);
auto
output
=
at
::
empty
({
prev_value
.
size
(
0
),
heads
,
seq_len
,
prev_value
.
size
(
2
)
/
heads
},
options
);
attention_unfused
<
T
>
(
prev_key_cont
,
query_cont
,
attn_mask
,
//(no_masking ? nullptr : (T*)attn_mask.data_ptr()),
prev_value_cont
,
output
,
bsz
,
seq_len
,
soft_len
,
heads
,
norm_factor
,
(
triangular
&&
(
new_size
==
0
)),
(
new_size
==
0
),
local_attention
,
window_size
);
return
{
output
,
prev_key
,
prev_value
};
}
template
<
typename
T
>
void
ds_softmax_internal
(
T
*
attn_scores
,
at
::
Tensor
&
attn_mask
,
at
::
Tensor
&
alibi
,
float
&
layer_scale
,
bool
triangular
,
bool
recompute
,
bool
local_attention
,
int
window_size
,
int
bsz
,
int
seq_len
,
int
soft_len
,
int
heads
)
{
auto
mask_stride
=
get_attn_mask_stride
(
attn_mask
);
launch_attn_softmax_v2
((
T
*
)
attn_scores
,
(
attn_mask
.
sizes
().
size
()
>
1
?
(
T
*
)
attn_mask
.
data_ptr
()
:
nullptr
),
(
alibi
.
sizes
().
size
()
>
1
?
(
T
*
)
alibi
.
data_ptr
()
:
nullptr
),
layer_scale
,
triangular
,
recompute
,
local_attention
,
window_size
,
bsz
,
heads
,
seq_len
,
soft_len
,
0
,
mask_stride
,
1
,
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
());
}
template
<
typename
T
>
void
attention_unfused
(
T
*
prev_key_cont
,
T
*
query_cont
,
at
::
Tensor
&
attn_mask
,
T
*
prev_value_cont
,
T
*
output
,
unsigned
&
bsz
,
int
&
k
,
unsigned
&
seq_len
,
unsigned
&
soft_len
,
int
&
heads
,
float
&
norm_factor
,
bool
triangular
,
bool
recompute
,
bool
local_attention
,
int
window_size
,
at
::
Tensor
&
alibi
,
int
layer_id
)
{
float
layer_scale
=
alibi
.
sizes
().
size
()
>
1
?
std
::
max
(
1
,
layer_id
)
:
1.0
;
float
alpha
=
norm_factor
*
norm_factor
/
layer_scale
;
float
gemm_beta
=
0.0
;
T
*
workspace
=
(
T
*
)
Context
::
Instance
().
GetAttentionUnfusedWorkspace
();
rocblas_set_stream
(
Context
::
Instance
().
GetCublasHandle
(),
Context
::
Instance
().
GetCurrentStream
());
cublas_strided_batched_gemm
(
Context
::
Instance
().
GetCublasHandle
(),
soft_len
,
seq_len
,
k
,
&
alpha
,
&
gemm_beta
,
(
T
*
)
prev_key_cont
,
(
T
*
)
query_cont
,
workspace
,
rocblas_operation_transpose
,
rocblas_operation_none
,
Context
::
Instance
().
GetMaxTokenLenght
()
*
k
,
seq_len
*
k
,
seq_len
*
soft_len
,
bsz
*
heads
,
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard
);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP
);
#endif
ds_softmax_internal
<
T
>
(
workspace
,
attn_mask
,
alibi
,
layer_scale
,
triangular
,
recompute
,
local_attention
,
window_size
,
bsz
,
seq_len
,
soft_len
,
heads
);
alpha
=
1.0
;
cublas_strided_batched_gemm
(
Context
::
Instance
().
GetCublasHandle
(),
k
,
seq_len
,
soft_len
,
&
alpha
,
&
gemm_beta
,
(
T
*
)
prev_value_cont
,
workspace
,
(
T
*
)
output
,
rocblas_operation_none
,
rocblas_operation_none
,
Context
::
Instance
().
GetMaxTokenLenght
()
*
k
,
seq_len
*
soft_len
,
seq_len
*
k
,
bsz
*
heads
,
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard
);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP
);
#endif
}
void
reset_cache
()
{
Context
::
Instance
().
reset_tokens
();
}
template
<
typename
T
>
std
::
vector
<
at
::
Tensor
>
ds_softmax_context
(
at
::
Tensor
&
query_key_value
,
at
::
Tensor
&
attn_mask
,
int
rotary_dim
,
bool
rotate_half
,
bool
rotate_every_two
,
int
heads
,
float
norm_factor
,
bool
triangular
,
bool
local_attention
,
int
window_size
,
bool
no_masking
,
unsigned
layer_id
,
unsigned
num_layers
,
at
::
Tensor
&
alibi
)
{
unsigned
bsz
=
query_key_value
.
size
(
0
);
unsigned
seq_len
=
query_key_value
.
size
(
1
);
unsigned
hidden_dim
=
query_key_value
.
size
(
2
)
/
3
;
bool
is_prompt
=
(
seq_len
>
1
);
if
(
is_prompt
)
Context
::
Instance
().
reset_tokens
(
seq_len
);
unsigned
soft_len
=
Context
::
Instance
().
current_tokens
();
int
k
=
hidden_dim
/
heads
;
auto
options
=
at
::
TensorOptions
()
.
dtype
(
query_key_value
.
options
().
dtype
())
.
layout
(
at
::
kStrided
)
.
device
(
at
::
kCUDA
)
.
requires_grad
(
false
);
T
*
workspace
=
(
T
*
)
Context
::
Instance
().
GetWorkSpace
();
size_t
buf_size
=
bsz
*
seq_len
*
hidden_dim
;
auto
output
=
torch
::
from_blob
(
workspace
+
4
*
buf_size
,
{
bsz
,
seq_len
,
hidden_dim
},
options
);
auto
query_cont
=
workspace
+
8
*
buf_size
;
size_t
offset
=
16
*
(
hidden_dim
*
bsz
*
Context
::
Instance
().
GetMaxTokenLenght
())
+
layer_id
*
2
*
bsz
*
Context
::
Instance
().
GetMaxTokenLenght
()
*
hidden_dim
;
unsigned
all_tokens
=
soft_len
;
auto
kv_cache
=
workspace
+
offset
+
(
hidden_dim
/
heads
)
*
(
is_prompt
?
0
:
soft_len
-
1
);
size_t
value_offset
=
bsz
*
Context
::
Instance
().
GetMaxTokenLenght
()
*
hidden_dim
;
T
*
temp_buf
=
(
T
*
)
output
.
data_ptr
()
+
at
::
numel
(
output
);
launch_bias_add_transform_0213
<
T
>
((
T
*
)
query_cont
,
kv_cache
,
kv_cache
+
value_offset
,
(
T
*
)
query_key_value
.
data_ptr
(),
nullptr
,
bsz
,
seq_len
,
(
is_prompt
?
0
:
soft_len
-
1
),
soft_len
,
hidden_dim
,
heads
,
rotary_dim
,
rotate_half
,
rotate_every_two
,
Context
::
Instance
().
GetCurrentStream
(),
3
,
Context
::
Instance
().
GetMaxTokenLenght
());
if
(
rotary_dim
>
0
&&
rotate_half
)
launch_apply_rotary_pos_emb
(
query_cont
,
kv_cache
,
k
,
seq_len
,
rotary_dim
,
(
is_prompt
?
0
:
soft_len
-
1
),
heads
,
bsz
,
rotate_half
,
rotate_every_two
,
Context
::
Instance
().
GetCurrentStream
(),
Context
::
Instance
().
GetMaxTokenLenght
());
attention_unfused
<
T
>
(
workspace
+
offset
,
(
T
*
)
query_cont
,
attn_mask
,
workspace
+
offset
+
value_offset
,
temp_buf
,
bsz
,
k
,
seq_len
,
all_tokens
,
heads
,
norm_factor
,
(
triangular
&&
is_prompt
),
is_prompt
,
local_attention
,
window_size
,
alibi
,
layer_id
);
launch_transform4d_0213
<
T
>
((
T
*
)
output
.
data_ptr
(),
temp_buf
,
bsz
,
heads
,
seq_len
,
output
.
size
(
2
),
Context
::
Instance
().
GetCurrentStream
(
false
),
1
);
if
(
layer_id
==
num_layers
-
1
)
Context
::
Instance
().
advance_tokens
();
auto
prev_key
=
torch
::
from_blob
(
workspace
+
offset
,
{
bsz
,
heads
,
all_tokens
,
k
},
options
);
auto
prev_value
=
torch
::
from_blob
(
workspace
+
offset
+
value_offset
,
{
bsz
,
heads
,
all_tokens
,
k
},
options
);
return
{
output
,
prev_key
,
prev_value
};
}
template
<
typename
T
>
at
::
Tensor
ds_bias_gelu
(
at
::
Tensor
&
input
,
at
::
Tensor
&
bias
)
{
auto
input_cont
=
input
.
contiguous
();
int
bsz
=
input_cont
.
size
(
0
)
*
input_cont
.
size
(
1
);
int
intermediate_size
=
input_cont
.
size
(
2
);
launch_bias_gelu
((
T
*
)
input_cont
.
data_ptr
(),
(
T
*
)
bias
.
data_ptr
(),
intermediate_size
,
bsz
,
Context
::
Instance
().
GetCurrentStream
());
return
input_cont
;
}
at
::
Tensor
ds_bias_geglu
(
at
::
Tensor
&
activation
,
at
::
Tensor
&
bias
)
{
/*
Used in FF of Stable diffusion
*/
const
int
batch_size
=
activation
.
size
(
0
);
const
int
seq_len
=
activation
.
size
(
1
);
const
int
channels
=
activation
.
size
(
2
);
const
int
rows
=
batch_size
*
seq_len
;
// Dimensionality is cut in half
const
int
out_channels
=
channels
/
2
;
auto
output
=
at
::
empty
({
batch_size
,
seq_len
,
out_channels
},
activation
.
options
());
if
(
activation
.
options
().
dtype
()
==
torch
::
kFloat32
)
{
launch_fused_bias_geglu
((
float
*
)
output
.
data_ptr
(),
(
const
float
*
)
activation
.
data_ptr
(),
(
const
float
*
)
bias
.
data_ptr
(),
rows
,
channels
,
Context
::
Instance
().
GetCurrentStream
());
}
else
{
launch_fused_bias_geglu
((
__half
*
)
output
.
data_ptr
(),
(
const
__half
*
)
activation
.
data_ptr
(),
(
const
__half
*
)
bias
.
data_ptr
(),
rows
,
channels
,
Context
::
Instance
().
GetCurrentStream
());
}
return
output
;
}
template
<
typename
T
>
at
::
Tensor
ds_bias_relu
(
at
::
Tensor
&
input
,
at
::
Tensor
&
bias
)
{
auto
input_cont
=
input
.
contiguous
();
int
bsz
=
input_cont
.
size
(
0
)
*
input_cont
.
size
(
1
);
int
intermediate_size
=
input_cont
.
size
(
2
);
launch_bias_relu
((
T
*
)
input_cont
.
data_ptr
(),
(
T
*
)
bias
.
data_ptr
(),
intermediate_size
,
bsz
,
Context
::
Instance
().
GetCurrentStream
());
return
input_cont
;
}
template
<
typename
T
>
at
::
Tensor
ds_bias_add
(
at
::
Tensor
&
input
,
at
::
Tensor
&
bias
)
{
auto
input_cont
=
input
.
contiguous
();
int
bsz
=
input_cont
.
size
(
0
)
*
input_cont
.
size
(
1
);
int
hidden_size
=
input_cont
.
size
(
2
);
launch_bias_add
((
T
*
)
input_cont
.
data_ptr
(),
(
T
*
)
bias
.
data_ptr
(),
hidden_size
,
bsz
,
Context
::
Instance
().
GetCurrentStream
());
return
input_cont
;
}
template
<
typename
T
>
at
::
Tensor
ds_bias_residual
(
at
::
Tensor
&
input
,
at
::
Tensor
&
residual
,
at
::
Tensor
&
bias
)
{
auto
input_cont
=
input
.
contiguous
();
auto
residual_cont
=
residual
.
contiguous
();
int
bsz
=
input_cont
.
size
(
0
)
*
input_cont
.
size
(
1
);
// launch_bias_residual((T*)input_cont.data_ptr(),
// (T*)residual_cont.data_ptr(),
// (T*)bias.data_ptr(),
// bsz,
// input_cont.size(2),
// (bias.size(0) > 1),
// Context::Instance().GetCurrentStream());
return
input_cont
;
}
at
::
Tensor
ds_layer_norm
(
at
::
Tensor
&
input
,
at
::
Tensor
&
gamma
,
at
::
Tensor
&
beta
,
float
epsilon
)
{
const
int
rows
=
input
.
size
(
0
)
*
input
.
size
(
1
);
const
int
elems_per_row
=
input
.
size
(
2
);
auto
output
=
at
::
empty_like
(
input
);
if
(
input
.
options
().
dtype
()
==
torch
::
kFloat16
)
{
launch_fused_ln
((
__half
*
)
output
.
data_ptr
(),
(
const
__half
*
)
input
.
data_ptr
(),
(
const
__half
*
)
gamma
.
data_ptr
(),
(
const
__half
*
)
beta
.
data_ptr
(),
epsilon
,
rows
,
elems_per_row
,
Context
::
Instance
().
GetCurrentStream
());
}
else
{
launch_fused_ln
((
float
*
)
output
.
data_ptr
(),
(
const
float
*
)
input
.
data_ptr
(),
(
const
float
*
)
gamma
.
data_ptr
(),
(
const
float
*
)
beta
.
data_ptr
(),
epsilon
,
rows
,
elems_per_row
,
Context
::
Instance
().
GetCurrentStream
());
}
return
output
;
}
template
<
typename
T
>
void
ds_layer_norm_internal
(
T
*
workspace
,
at
::
Tensor
&
input
,
at
::
Tensor
&
gamma
,
at
::
Tensor
&
beta
,
float
epsilon
)
{
int
bsz
=
input
.
size
(
0
)
*
input
.
size
(
1
);
launch_fused_ln
(
workspace
,
(
const
T
*
)
input
.
data_ptr
(),
(
const
T
*
)
gamma
.
data_ptr
(),
(
const
T
*
)
beta
.
data_ptr
(),
epsilon
,
bsz
,
input
.
size
(
2
),
Context
::
Instance
().
GetCurrentStream
());
}
/* Currently only used in unit testing */
at
::
Tensor
ds_layer_norm_residual
(
at
::
Tensor
&
input
,
at
::
Tensor
&
bias
,
at
::
Tensor
&
residual
,
at
::
Tensor
&
gamma
,
at
::
Tensor
&
beta
,
float
epsilon
)
{
const
int
rows
=
input
.
size
(
0
)
*
input
.
size
(
1
);
const
int
elems_per_row
=
input
.
size
(
2
);
auto
output
=
at
::
empty_like
(
input
);
if
(
input
.
options
().
dtype
()
==
torch
::
kFloat16
)
{
launch_fused_residual_ln
((
__half
*
)
output
.
data_ptr
(),
(
const
__half
*
)
input
.
data_ptr
(),
(
const
__half
*
)
residual
.
data_ptr
(),
(
const
__half
*
)
bias
.
data_ptr
(),
(
const
__half
*
)
gamma
.
data_ptr
(),
(
const
__half
*
)
beta
.
data_ptr
(),
epsilon
,
rows
,
elems_per_row
,
Context
::
Instance
().
GetCurrentStream
());
}
else
{
launch_fused_residual_ln
((
float
*
)
output
.
data_ptr
(),
(
const
float
*
)
input
.
data_ptr
(),
(
const
float
*
)
residual
.
data_ptr
(),
(
const
float
*
)
bias
.
data_ptr
(),
(
const
float
*
)
gamma
.
data_ptr
(),
(
const
float
*
)
beta
.
data_ptr
(),
epsilon
,
rows
,
elems_per_row
,
Context
::
Instance
().
GetCurrentStream
());
}
return
output
;
}
/* Currently only used in unit testing */
std
::
vector
<
at
::
Tensor
>
ds_layer_norm_residual_store_pre_ln_res
(
at
::
Tensor
&
input
,
at
::
Tensor
&
bias
,
at
::
Tensor
&
residual
,
at
::
Tensor
&
gamma
,
at
::
Tensor
&
beta
,
float
epsilon
)
{
const
int
rows
=
input
.
size
(
0
)
*
input
.
size
(
1
);
const
int
elems_per_row
=
input
.
size
(
2
);
auto
norm_output
=
at
::
empty_like
(
input
);
auto
res_output
=
at
::
empty_like
(
input
);
if
(
input
.
options
().
dtype
()
==
torch
::
kFloat16
)
{
launch_fused_residual_ln_store_pre_ln_res
((
__half
*
)
norm_output
.
data_ptr
(),
(
__half
*
)
res_output
.
data_ptr
(),
(
const
__half
*
)
input
.
data_ptr
(),
(
const
__half
*
)
residual
.
data_ptr
(),
(
const
__half
*
)
bias
.
data_ptr
(),
(
const
__half
*
)
gamma
.
data_ptr
(),
(
const
__half
*
)
beta
.
data_ptr
(),
epsilon
,
rows
,
elems_per_row
,
Context
::
Instance
().
GetCurrentStream
());
}
else
{
launch_fused_residual_ln_store_pre_ln_res
((
float
*
)
norm_output
.
data_ptr
(),
(
float
*
)
res_output
.
data_ptr
(),
(
const
float
*
)
input
.
data_ptr
(),
(
const
float
*
)
residual
.
data_ptr
(),
(
const
float
*
)
bias
.
data_ptr
(),
(
const
float
*
)
gamma
.
data_ptr
(),
(
const
float
*
)
beta
.
data_ptr
(),
epsilon
,
rows
,
elems_per_row
,
Context
::
Instance
().
GetCurrentStream
());
}
return
{
norm_output
,
res_output
};
}
template
<
typename
T
>
void
quantized_gemm
(
void
*
output
,
T
*
input
,
at
::
Tensor
&
weight
,
at
::
Tensor
&
qscale
,
int
groups
,
int
bsz
,
int
hidden_size
)
{
// T* weight16 = (T*)Context::Instance().GetWorkSpace() + 12 * hidden_size * bsz;
auto
options
=
at
::
TensorOptions
()
.
dtype
(
at
::
kHalf
)
.
layout
(
at
::
kStrided
)
.
device
(
at
::
kCUDA
)
.
requires_grad
(
false
);
auto
tmp
=
torch
::
empty
(
weight
.
sizes
(),
options
);
T
*
weight16
=
(
T
*
)
tmp
.
data_ptr
();
launch_dequantize
(
weight16
,
(
int8_t
*
)
weight
.
data_ptr
(),
(
float
*
)
qscale
.
data_ptr
(),
weight
.
size
(
0
),
weight
.
size
(
1
),
groups
,
Context
::
Instance
().
GetCurrentStream
());
float
alpha
=
(
T
)
1.0
;
float
gemm_beta
=
(
T
)
0.0
;
cublas_gemm_ex
(
Context
::
Instance
().
GetCublasHandle
(),
rocblas_operation_transpose
,
rocblas_operation_none
,
weight
.
size
(
0
),
bsz
,
weight
.
size
(
1
),
&
alpha
,
&
gemm_beta
,
weight16
,
(
T
*
)
input
,
(
T
*
)
output
,
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard
);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP
);
#endif
}
template
<
typename
T
>
at
::
Tensor
qkv_unfused_cublas
(
at
::
Tensor
&
output
,
at
::
Tensor
&
input
,
at
::
Tensor
&
weight
,
at
::
Tensor
&
q_scale
,
at
::
Tensor
&
bias
,
at
::
Tensor
&
gamma
,
at
::
Tensor
&
beta
,
const
float
epsilon
,
bool
add_bias
,
bool
q_int8
)
{
int
bsz
=
input
.
size
(
0
)
*
input
.
size
(
1
);
T
*
workspace
=
(
T
*
)
Context
::
Instance
().
GetWorkSpace
();
workspace
+=
(
3
*
bsz
*
input
.
size
(
2
));
ds_layer_norm_internal
<
T
>
(
workspace
,
input
,
gamma
,
beta
,
epsilon
);
if
(
q_int8
)
{
quantized_gemm
<
T
>
(
output
.
data_ptr
(),
workspace
,
weight
,
q_scale
,
q_scale
.
size
(
0
),
bsz
,
input
.
size
(
2
));
}
else
{
float
alpha
=
(
T
)
1.0
;
float
gemm_beta
=
(
T
)
0.0
;
rocblas_set_stream
(
Context
::
Instance
().
GetCublasHandle
(),
Context
::
Instance
().
GetCurrentStream
());
cublas_gemm_ex
(
Context
::
Instance
().
GetCublasHandle
(),
rocblas_operation_none
,
rocblas_operation_none
,
weight
.
size
(
1
),
bsz
,
input
.
size
(
2
),
&
alpha
,
&
gemm_beta
,
(
T
*
)
weight
.
data_ptr
(),
workspace
,
(
T
*
)
output
.
data_ptr
(),
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard
);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP
);
#endif
}
if
(
add_bias
)
launch_bias_add
((
T
*
)
output
.
data_ptr
(),
(
T
*
)
bias
.
data_ptr
(),
q_int8
?
weight
.
size
(
0
)
:
weight
.
size
(
1
),
bsz
,
Context
::
Instance
().
GetCurrentStream
());
return
torch
::
from_blob
(
workspace
,
input
.
sizes
(),
input
.
options
());
}
template
<
typename
T
>
std
::
vector
<
at
::
Tensor
>
ds_qkv_gemm
(
at
::
Tensor
&
input
,
at
::
Tensor
&
weight
,
at
::
Tensor
&
q_scale
,
at
::
Tensor
&
bias
,
at
::
Tensor
&
gamma
,
at
::
Tensor
&
beta
,
const
float
epsilon
,
bool
add_bias
,
unsigned
num_layers
,
bool
external_cache
,
unsigned
mp_size
,
unsigned
rank
,
bool
q_int8
)
{
int
bsz
=
input
.
size
(
0
)
*
input
.
size
(
1
);
T
*
workspace
=
(
T
*
)
Context
::
Instance
().
GetWorkSpace
();
int
out_size
=
q_int8
?
weight
.
size
(
0
)
:
weight
.
size
(
1
);
auto
options
=
at
::
TensorOptions
()
.
dtype
(
input
.
options
().
dtype
())
.
layout
(
at
::
kStrided
)
.
device
(
at
::
kCUDA
)
.
requires_grad
(
false
);
auto
output
=
at
::
from_blob
(
workspace
,
{
input
.
size
(
0
),
input
.
size
(
1
),
out_size
},
options
);
auto
inp_norm
=
qkv_unfused_cublas
<
T
>
(
output
,
input
,
weight
,
q_scale
,
bias
,
gamma
,
beta
,
epsilon
,
add_bias
,
q_int8
);
return
{
output
,
inp_norm
};
}
template
<
typename
T
>
void
quantized_gemm
(
at
::
Tensor
&
output
,
at
::
Tensor
&
input
,
at
::
Tensor
&
weight
,
at
::
Tensor
&
qscale
,
int
groups
,
int
merge_count
)
{
int
bsz
=
input
.
size
(
0
)
*
input
.
size
(
1
);
auto
options
=
at
::
TensorOptions
()
.
dtype
(
input
.
options
().
dtype
())
.
layout
(
at
::
kStrided
)
.
device
(
at
::
kCUDA
)
.
requires_grad
(
false
);
auto
weight16
=
at
::
empty
({
weight
.
size
(
0
),
weight
.
size
(
1
)},
options
);
launch_dequantize
((
T
*
)
weight16
.
data_ptr
(),
(
int8_t
*
)
weight
.
data_ptr
(),
(
float
*
)
qscale
.
data_ptr
(),
weight
.
size
(
0
),
weight
.
size
(
1
),
groups
,
merge_count
,
Context
::
Instance
().
GetCurrentStream
());
float
alpha
=
(
T
)
1.0
;
float
gemm_beta
=
(
T
)
0.0
;
cublas_gemm_ex
(
Context
::
Instance
().
GetCublasHandle
(),
rocblas_operation_transpose
,
rocblas_operation_none
,
weight
.
size
(
0
),
bsz
,
input
.
size
(
2
),
&
alpha
,
&
gemm_beta
,
(
T
*
)
weight16
.
data_ptr
(),
(
T
*
)
input
.
data_ptr
(),
(
T
*
)
output
.
data_ptr
(),
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard
);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP
);
#endif
}
template
<
typename
T
>
at
::
Tensor
ds_qkv_gemm_int8
(
at
::
Tensor
&
input
,
at
::
Tensor
&
weight
,
at
::
Tensor
&
bias
,
at
::
Tensor
&
gamma
,
at
::
Tensor
&
beta
,
const
float
epsilon
,
at
::
Tensor
&
q_scale
,
int
groups
,
bool
add_bias
)
{
int
bsz
=
input
.
size
(
0
)
*
input
.
size
(
1
);
auto
input_cont
=
input
.
contiguous
();
auto
options
=
at
::
TensorOptions
()
.
dtype
(
input_cont
.
options
().
dtype
())
.
layout
(
at
::
kStrided
)
.
device
(
at
::
kCUDA
)
.
requires_grad
(
false
);
auto
output
=
at
::
empty
({
input_cont
.
size
(
0
),
input_cont
.
size
(
1
),
weight
.
size
(
1
)},
options
);
auto
inp_norm
=
ds_layer_norm
(
input_cont
,
gamma
,
beta
,
epsilon
);
quantized_gemm
<
T
>
(
output
,
inp_norm
,
weight
,
q_scale
,
groups
,
0
);
if
(
add_bias
)
launch_bias_add
((
T
*
)
output
.
data_ptr
(),
(
T
*
)
bias
.
data_ptr
(),
weight
.
size
(
1
),
bsz
,
Context
::
Instance
().
GetCurrentStream
());
return
output
;
}
template <typename T>
at::Tensor ds_linear_layer(at::Tensor& input,
                           at::Tensor& weight,
                           at::Tensor& bias,
                           bool add_bias,
                           bool do_flash_attn,
                           int num_heads)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);

    int head_size = input_cont.size(2) / num_heads;
    int bsz = input.size(0) * input.size(1);
    T* workspace = (T*)Context::Instance().GetWorkSpace();
    auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options);

    float alpha = (T)1.0;
    float gemm_beta = (T)0.0;
    rocblas_set_stream(Context::Instance().GetCublasHandle(),
                       Context::Instance().GetCurrentStream());

    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
                   rocblas_operation_none,
                   rocblas_operation_none,
                   weight.size(1),
                   bsz,
                   input_cont.size(2),
                   &alpha,
                   &gemm_beta,
                   (T*)weight.data_ptr(),
                   (T*)input_cont.data_ptr(),
                   (T*)output.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__
                   rocblas_gemm_algo_standard);
#else
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
    if (add_bias)
        launch_bias_add((T*)output.data_ptr(),
                        (T*)bias.data_ptr(),
                        weight.size(1),
                        bsz,
                        Context::Instance().GetCurrentStream());
    bool add_padding = (head_size % 32 != 0 && head_size < 64) || (head_size % 64 != 0);
    if (do_flash_attn) {
        if (add_padding) {
            int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 64 : 128);
            auto padded_output = workspace + output.numel();
            auto final_output =
                padded_output + (input.size(0) * input.size(1) * 3 * num_heads * padded_head_size);
            pad_data(padded_output,
                     workspace,
                     3 * bsz * num_heads,
                     head_size,
                     padded_head_size,
                     Context::Instance().GetCurrentStream());

            launch_bias_add_transform_0213<T>(
                final_output,
                final_output + (input.size(0) * input.size(1) * num_heads * padded_head_size),
                final_output + (input.size(0) * input.size(1) * 2 * num_heads * padded_head_size),
                padded_output,
                nullptr,
                input.size(0),
                input.size(1),
                0,
                input.size(1),
                (num_heads * padded_head_size),
                num_heads,
                -1,
                false,
                false,
                Context::Instance().GetCurrentStream(),
                3,
                input.size(1));
            return at::from_blob(final_output,
                                 {3, input.size(0), num_heads, input.size(1), padded_head_size},
                                 options);
            // return at::from_blob(padded_output, {input.size(0) * input.size(1), 3, num_heads,
            // padded_head_size}, options);
        } else {
            auto final_output = workspace + output.numel();
            launch_bias_add_transform_0213<T>(
                final_output,
                final_output + (input.size(0) * input.size(1) * input_cont.size(2)),
                final_output + (input.size(0) * input.size(1) * 2 * input_cont.size(2)),
                workspace,
                nullptr,
                input.size(0),
                input.size(1),
                0,
                input.size(1),
                input_cont.size(2),
                num_heads,
                -1,
                false,
                false,
                Context::Instance().GetCurrentStream(),
                3,
                input.size(1));
            return at::from_blob(
                final_output, {3, input.size(0), num_heads, input.size(1), head_size}, options);
            // return at::from_blob(workspace, {input.size(0) * input.size(1), 3, num_heads,
            // head_size}, options);
        }
    } else
        return output;
}
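The flash-attention branch above pads irregular head sizes up to the next supported bucket (32, 64, or 128). As a quick illustration, here is a minimal, self-contained host-side sketch of that rule; it is not part of the original file, and the sample head sizes are hypothetical. It simply mirrors the add_padding predicate and the padded_head_size selection used in ds_linear_layer.

// Sketch only: mirrors the padding rule from ds_linear_layer above.
#include <cstdio>

static bool needs_padding(int head_size)
{
    // Same predicate as the add_padding flag in ds_linear_layer.
    return (head_size % 32 != 0 && head_size < 64) || (head_size % 64 != 0);
}

static int padded_head_size(int head_size)
{
    // Same 32 / 64 / 128 bucket selection as the kernels expect.
    return head_size < 32 ? 32 : (head_size < 64 ? 64 : 128);
}

int main()
{
    const int samples[] = {20, 64, 80, 128};  // illustrative head sizes
    for (int hs : samples)
        printf("head_size=%d needs_padding=%d padded=%d\n",
               hs, (int)needs_padding(hs), padded_head_size(hs));
    return 0;
}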
template <typename T>
std::vector<at::Tensor> add_padding(at::Tensor& query, at::Tensor& key, at::Tensor& value)
{
    int head_size = query.size(3);
    int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 64 : 128);
    T* workspace = (T*)Context::Instance().GetWorkSpace();
    T* key_pad_ptr = workspace + padded_head_size * query.size(0) * query.size(1) * query.size(2);
    T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * query.size(1) * 128;
    pad_head_seq(workspace,
                 (T*)query.data_ptr(),
                 query.size(0) * query.size(1),
                 query.size(2),
                 query.size(2),
                 head_size,
                 padded_head_size,
                 Context::Instance().GetCurrentStream());
    pad_head_seq(key_pad_ptr,
                 (T*)key.data_ptr(),
                 query.size(0) * query.size(1),
                 key.size(2),
                 128,
                 head_size,
                 padded_head_size,
                 Context::Instance().GetCurrentStream());
    pad_head_seq(value_pad_ptr,
                 (T*)value.data_ptr(),
                 query.size(0) * query.size(1),
                 key.size(2),
                 128,
                 head_size,
                 padded_head_size,
                 Context::Instance().GetCurrentStream());
    return {
        at::from_blob(workspace,
                      {query.size(0), query.size(1), query.size(2), padded_head_size},
                      query.options()),
        at::from_blob(
            key_pad_ptr, {query.size(0), query.size(1), 128, padded_head_size}, query.options()),
        at::from_blob(
            value_pad_ptr, {query.size(0), query.size(1), 128, padded_head_size}, query.options())};
}
template <typename T>
std::vector<at::Tensor> padd_add_transform(at::Tensor& query,
                                           at::Tensor& key,
                                           at::Tensor& value,
                                           int heads,
                                           bool add_padding)
{
    int head_size = query.size(2) / heads;
    int key_value_length = add_padding ? 128 : key.size(1);
    int padded_head_size =
        add_padding ? (head_size < 32 ? 32 : (head_size < 64 ? 64 : 128)) : head_size;
    T* workspace = (T*)Context::Instance().GetWorkSpace();
    T* key_pad_ptr = workspace + padded_head_size * query.size(0) * heads * query.size(1);
    T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * heads * key_value_length;
    launch_pad_add_transform_0213(workspace,
                                  (T*)query.data_ptr(),
                                  query.size(0),
                                  query.size(2),
                                  query.size(1),
                                  query.size(1),
                                  heads,
                                  padded_head_size,
                                  Context::Instance().GetCurrentStream());
    launch_pad_add_transform_0213(key_pad_ptr,
                                  (T*)key.data_ptr(),
                                  key.size(0),
                                  key.size(2),
                                  key.size(1),
                                  key_value_length,
                                  heads,
                                  padded_head_size,
                                  Context::Instance().GetCurrentStream());
    launch_pad_add_transform_0213(value_pad_ptr,
                                  (T*)value.data_ptr(),
                                  value.size(0),
                                  value.size(2),
                                  value.size(1),
                                  key_value_length,
                                  heads,
                                  padded_head_size,
                                  Context::Instance().GetCurrentStream());
    return {
        at::from_blob(
            workspace, {query.size(0), heads, query.size(1), padded_head_size}, query.options()),
        at::from_blob(key_pad_ptr,
                      {query.size(0), heads, key_value_length, padded_head_size},
                      query.options()),
        at::from_blob(value_pad_ptr,
                      {query.size(0), heads, key_value_length, padded_head_size},
                      query.options())};
}
template <typename T>
at::Tensor ds_linear_layer_int8(at::Tensor& input,
                                at::Tensor& weight,
                                at::Tensor& bias,
                                at::Tensor& q_scale,
                                int groups)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    int bsz = input_cont.size(0) * input_cont.size(1);

    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);

    quantized_gemm<T>(output, input_cont, weight, q_scale, groups, 0);
    launch_bias_add((T*)output.data_ptr(),
                    (T*)bias.data_ptr(),
                    weight.size(1),
                    bsz,
                    Context::Instance().GetCurrentStream());
    return output;
}
template <typename T>
at::Tensor ds_vector_matmul(at::Tensor& input,
                            at::Tensor& weight,
                            bool async_op,
                            at::Tensor& q_scale,
                            bool q_int8)
{
    auto options = at::TensorOptions()
                       .dtype(input.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);

    int out_size = q_int8 ? weight.size(0) : weight.size(1);
    int bsz = input.size(0) * input.size(1);

    T* workspace = (T*)Context::Instance().GetWorkSpace();
    auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options);
    if (q_int8) {
        quantized_gemm<T>(output.data_ptr(),
                          (T*)input.data_ptr(),
                          weight,
                          q_scale,
                          q_scale.size(0),
                          bsz,
                          input.size(2));
    } else {
        float alpha = (T)1.0;
        float gemm_beta = (T)0.0;
        rocblas_set_stream(Context::Instance().GetCublasHandle(),
                           Context::Instance().GetCurrentStream(async_op));
        cublas_gemm_ex(Context::Instance().GetCublasHandle(),
                       rocblas_operation_none,
                       rocblas_operation_none,
                       weight.size(1),
                       bsz,
                       input.size(2),
                       &alpha,
                       &gemm_beta,
                       (T*)weight.data_ptr(),
                       (T*)input.data_ptr(),
                       (T*)output.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__
                       rocblas_gemm_algo_standard);
#else
                       CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
    }
    return output;
}
template <typename T>
at::Tensor ds_vector_matmul_int8(at::Tensor& input,
                                 at::Tensor& weight,
                                 at::Tensor& q_scale,
                                 int groups,
                                 int merge_count)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);

    quantized_gemm<T>(output, input_cont, weight, q_scale, groups, merge_count);
    return output;
}
template <typename T>
at::Tensor mlp_unfused_cublas(at::Tensor& output,
                              at::Tensor& input,
                              at::Tensor& residual,
                              at::Tensor& input_bias,
                              at::Tensor& weight,
                              at::Tensor& weight1,
                              at::Tensor& bias,
                              at::Tensor& gamma,
                              at::Tensor& beta,
                              const float epsilon,
                              bool preLayerNorm,
                              bool mlp_after_attn,
                              at::Tensor& q_scale,
                              at::Tensor& q_scale1,
                              bool q_int8,
                              ActivationFuncType act_func_type)
{
    int bsz = input.size(0) * input.size(1);
    T* inp_norm =
        (T*)Context::Instance().GetWorkSpace() + torch::numel(input) + torch::numel(output);
    T* intermediate = inp_norm + torch::numel(input);

    if (mlp_after_attn) {
        launch_fused_residual_ln((T*)inp_norm,
                                 (const T*)input.data_ptr(),
                                 (const T*)residual.data_ptr(),
                                 (const T*)input_bias.data_ptr(),
                                 (const T*)gamma.data_ptr(),
                                 (const T*)beta.data_ptr(),
                                 epsilon,
                                 bsz,
                                 input.size(2),
                                 Context::Instance().GetCurrentStream());
    } else {
        ds_layer_norm_internal(inp_norm, input, gamma, beta, epsilon);
    }
    if (q_int8) {
        quantized_gemm<T>(
            intermediate, inp_norm, weight, q_scale, q_scale.size(0), bsz, input.size(2));
    } else {
        float alpha = (T)1.0;
        float gemm_beta = (T)0.0;
        rocblas_set_stream(Context::Instance().GetCublasHandle(),
                           Context::Instance().GetCurrentStream());
        cublas_gemm_ex(Context::Instance().GetCublasHandle(),
                       rocblas_operation_none,
                       rocblas_operation_none,
                       weight.size(1),
                       bsz,
                       input.size(2),
                       &alpha,
                       &gemm_beta,
                       (T*)weight.data_ptr(),
                       inp_norm,
                       intermediate,
#ifdef __HIP_PLATFORM_HCC__
                       rocblas_gemm_algo_standard);
#else
                       CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
    }
    if (act_func_type == ActivationFuncType::GELU) {
        launch_bias_gelu(intermediate,
                         (T*)bias.data_ptr(),
                         q_int8 ? weight.size(0) : weight.size(1),
                         bsz,
                         Context::Instance().GetCurrentStream());
    } else if (act_func_type == ActivationFuncType::ReLU) {
        launch_bias_relu(intermediate,
                         (T*)bias.data_ptr(),
                         q_int8 ? weight.size(0) : weight.size(1),
                         bsz,
                         Context::Instance().GetCurrentStream());
    }
    if (q_int8) {
        quantized_gemm<T>(output.data_ptr(),
                          intermediate,
                          weight1,
                          q_scale1,
                          q_scale1.size(0),
                          bsz,
                          input.size(2));
    } else {
        float alpha = (T)1.0;
        float gemm_beta = (T)0.0;
        rocblas_set_stream(Context::Instance().GetCublasHandle(),
                           Context::Instance().GetCurrentStream());
        cublas_gemm_ex(Context::Instance().GetCublasHandle(),
                       rocblas_operation_none,
                       rocblas_operation_none,
                       weight1.size(1),
                       bsz,
                       weight1.size(0),
                       &alpha,
                       &gemm_beta,
                       (T*)weight1.data_ptr(),
                       intermediate,
                       (T*)output.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__
                       rocblas_gemm_algo_standard);
#else
                       CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
    }
    return torch::from_blob(inp_norm, input.sizes(), input.options());
}
template <typename T>
std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
                                    at::Tensor& residual,
                                    at::Tensor& input_bias,
                                    at::Tensor& weight_interm,
                                    at::Tensor& weight_out,
                                    at::Tensor& bias,
                                    at::Tensor& gamma,
                                    at::Tensor& beta,
                                    const float epsilon,
                                    bool preLayerNorm,
                                    bool mlp_after_attn,
                                    at::Tensor& q_scale,
                                    at::Tensor& q_scale1,
                                    bool q_int8,
                                    int activation_type)
{
    auto options = at::TensorOptions()
                       .dtype(input.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);

    int out_size = q_int8 ? weight_out.size(0) : weight_out.size(1);
    auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input),
                                {input.size(0), input.size(1), out_size},
                                options);
    int bsz = input.size(0) * input.size(1);

    auto act_func_type = static_cast<ActivationFuncType>(activation_type);
    auto res_add = mlp_unfused_cublas<T>(output,
                                         mlp_after_attn ? input : residual,
                                         residual,
                                         input_bias,
                                         weight_interm,
                                         weight_out,
                                         bias,
                                         gamma,
                                         beta,
                                         epsilon,
                                         preLayerNorm,
                                         mlp_after_attn,
                                         q_scale,
                                         q_scale1,
                                         q_int8,
                                         act_func_type);

    return {output, res_add};
}
template <typename T>
std::vector<at::Tensor> ds_mlp_gemm_int8(at::Tensor& input,
                                         at::Tensor& residual,
                                         at::Tensor& input_bias,
                                         at::Tensor& weight,
                                         at::Tensor& bias,
                                         at::Tensor& gamma,
                                         at::Tensor& beta,
                                         const float epsilon,
                                         at::Tensor& q_scale,
                                         int groups,
                                         bool preLayerNorm)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);

    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);

    int bsz = input_cont.size(0) * input_cont.size(1);
    auto inp_norm = at::empty_like(input_cont);

    auto residual_add = (preLayerNorm ? at::empty_like(input_cont) : inp_norm);
    quantized_gemm<T>(output, inp_norm, weight, q_scale, groups, 0);
    launch_bias_gelu((T*)output.data_ptr(),
                     (T*)bias.data_ptr(),
                     weight.size(1),
                     bsz,
                     Context::Instance().GetCurrentStream());

    return {output, residual_add};
}
template <typename T>
at::Tensor fused_gemm_gelu(at::Tensor& input,
                           at::Tensor& weight,
                           at::Tensor& weight_scale,
                           at::Tensor& bias,
                           at::Tensor& weight_out,
                           at::Tensor& weight_out_scale,
                           const float epsilon,
                           bool preLayerNorm,
                           bool q_int8,
                           bool async_op)
{
    auto options = at::TensorOptions()
                       .dtype(input.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);

    int intm_dim = q_int8 ? weight.size(0) : weight.size(1);

    // auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input),
    //                            {input.size(0), input.size(1), out_size},
    //                            options);
    // T* intermediate = (T*)input.data_ptr() + torch::numel(input);
    auto intermediate = at::empty({input.size(0), input.size(1), intm_dim}, options);

    int bsz = input.size(0) * input.size(1);

    float alpha = (T)1.0;
    float gemm_beta = (T)0.0;
    if (q_int8) {
        quantized_gemm<T>(intermediate.data_ptr(),
                          (T*)input.data_ptr(),
                          weight,
                          weight_scale,
                          weight_scale.size(0),
                          bsz,
                          input.size(2));
    } else {
        rocblas_set_stream(Context::Instance().GetCublasHandle(),
                           Context::Instance().GetCurrentStream());
        cublas_gemm_ex(Context::Instance().GetCublasHandle(),
                       rocblas_operation_none,
                       rocblas_operation_none,
                       intm_dim,
                       bsz,
                       input.size(2),
                       &alpha,
                       &gemm_beta,
                       (T*)weight.data_ptr(),
                       (T*)input.data_ptr(),
                       (T*)intermediate.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__
                       rocblas_gemm_algo_standard);
#else
                       CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
    }
    launch_bias_gelu((T*)intermediate.data_ptr(),
                     (T*)bias.data_ptr(),
                     intm_dim,
                     bsz,
                     Context::Instance().GetCurrentStream());

    int out_size = q_int8 ? weight_out.size(0) : weight_out.size(1);
    auto output = at::empty({input.size(0), input.size(1), out_size}, options);
    if (q_int8) {
        quantized_gemm<T>(output.data_ptr(),
                          (T*)intermediate.data_ptr(),
                          weight_out,
                          weight_out_scale,
                          weight_out_scale.size(0),
                          bsz,
                          input.size(2));
    } else {
        cublas_gemm_ex(Context::Instance().GetCublasHandle(),
                       rocblas_operation_none,
                       rocblas_operation_none,
                       out_size,
                       bsz,
                       intm_dim,
                       &alpha,
                       &gemm_beta,
                       (T*)weight_out.data_ptr(),
                       (T*)intermediate.data_ptr(),
                       (T*)output.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__
                       rocblas_gemm_algo_standard);
#else
                       CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
    }
    // hipEventRecord(Context::Instance().GetCompEvent(2),
    //               Context::Instance().GetCurrentStream(true));
    return output;
}
template <typename T>
at::Tensor& residual_add_bias(at::Tensor& hidden_state,
                              at::Tensor& residual,
                              const at::Tensor& attention_output,
                              const at::Tensor& attention_bias,
                              const at::Tensor& final_bias,
                              const int mp_size,
                              const bool mlp_after_attn,
                              const bool add_bias,
                              const bool preln)
{
    int bsz = residual.size(0) * residual.size(1);
    int hidden_size = residual.size(2);
    if (mlp_after_attn)
        launch_bias_residual(static_cast<T*>(residual.data_ptr()),
                             static_cast<T*>(hidden_state.data_ptr()),
                             static_cast<T*>(attention_output.data_ptr()),
                             static_cast<T*>(final_bias.data_ptr()),
                             static_cast<T*>(attention_bias.data_ptr()),
                             bsz,
                             hidden_size,
                             mp_size,
                             preln,
                             Context::Instance().GetCurrentStream());
    else
        launch_gptj_residual_add<T>(
            static_cast<T*>(residual.data_ptr()),
            static_cast<T*>(hidden_state.data_ptr()),
            static_cast<T*>(attention_output.data_ptr()),
            static_cast<T*>(final_bias.data_ptr()),
            static_cast<T*>((add_bias ? attention_bias.data_ptr() : nullptr)),
            hidden_size,
            bsz,
            mp_size,
            Context::Instance().GetCurrentStream());
    return residual;
}
std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
                                             at::Tensor& key_layer,
                                             unsigned rotary_dim,
                                             unsigned offset,
                                             unsigned num_heads,
                                             bool rotate_half,
                                             bool rotate_every_two)
{
    auto query_cont = mixed_query.contiguous();
    auto key_cont = key_layer.contiguous();

    unsigned bsz = mixed_query.size(0);
    unsigned head_size = mixed_query.size(2) / num_heads;
    unsigned seq_len = mixed_query.size(1);

    if (mixed_query.scalar_type() == at::kFloat)
        launch_apply_rotary_pos_emb<float>((float*)query_cont.data_ptr(),
                                           (float*)key_cont.data_ptr(),
                                           head_size,
                                           seq_len,
                                           rotary_dim,
                                           offset,
                                           num_heads,
                                           bsz,
                                           rotate_half,
                                           rotate_every_two,
                                           Context::Instance().GetCurrentStream(),
                                           Context::Instance().GetMaxTokenLenght());
    else
        launch_apply_rotary_pos_emb<__half>((__half*)query_cont.data_ptr(),
                                            (__half*)key_cont.data_ptr(),
                                            head_size,
                                            seq_len,
                                            rotary_dim,
                                            offset,
                                            num_heads,
                                            bsz,
                                            rotate_half,
                                            rotate_every_two,
                                            Context::Instance().GetCurrentStream(),
                                            Context::Instance().GetMaxTokenLenght());
    return {query_cont, key_cont};
}
template <typename T>
at::Tensor fused_gemm_gelu_int8(at::Tensor& input,
                                at::Tensor& weight,
                                at::Tensor& bias,
                                const float epsilon,
                                at::Tensor& q_scale,
                                int groups,
                                bool preLayerNorm)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);

    int bsz = input_cont.size(0) * input_cont.size(1);

    quantized_gemm<T>(output, input_cont, weight, q_scale, groups, 0);
    launch_bias_gelu((T*)output.data_ptr(),
                     (T*)bias.data_ptr(),
                     weight.size(1),
                     bsz,
                     Context::Instance().GetCurrentStream());
    return output;
}
at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& output)
{
    int M = moe_res.size(0) * moe_res.size(1);
    int N = moe_res.size(2);
    Context::Instance().SynchComm();
    if (moe_res.scalar_type() == at::kFloat) {
        launch_moe_res_matmul<float>((float*)moe_res.data_ptr(),
                                     (float*)coef.data_ptr(),
                                     (float*)output.data_ptr(),
                                     M,
                                     N,
                                     at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
    } else {
        launch_moe_res_matmul<__half>((__half*)moe_res.data_ptr(),
                                      (__half*)coef.data_ptr(),
                                      (__half*)output.data_ptr(),
                                      M,
                                      N,
                                      at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
    }
    return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("softmax_fp32", &ds_softmax<float>, "DeepSpeed SoftMax with fp32 (CUDA)");
    m.def("softmax_fp16", &ds_softmax<__half>, "DeepSpeed SoftMax with fp16 (CUDA)");
    m.def("softmax_context_fp32",
          &ds_softmax_context<float>,
          "DeepSpeed attention with fp32 (CUDA)");
    m.def("softmax_context_fp16",
          &ds_softmax_context<__half>,
          "DeepSpeed attention with fp16 (CUDA)");
    m.def("softmax_context_int8",
          &ds_softmax_context1<__half>,
          "DeepSpeed attention with int8 (CUDA)");
    m.def("bias_gelu_fp32", &ds_bias_gelu<float>, "DeepSpeed Gelu with fp32 (CUDA)");
    m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp16 (CUDA)");
    m.def("bias_geglu", &ds_bias_geglu, "DeepSpeed Bias GEGLU (CUDA)");
    m.def("bias_add_fp32", &ds_bias_add<float>, "DeepSpeed Bias Add with fp32 (CUDA)");
    m.def("bias_add_fp16", &ds_bias_add<__half>, "DeepSpeed Gelu with fp16 (CUDA)");
    m.def("bias_relu_fp32", &ds_bias_relu<float>, "DeepSpeed ReLU with fp32 (CUDA)");
    m.def("bias_relu_fp16", &ds_bias_relu<__half>, "DeepSpeed ReLU with fp16 (CUDA)");
    m.def("bias_residual_fp32",
          &ds_bias_residual<float>,
          "DeepSpeed residual-bias add with fp32 (CUDA)");
    m.def("bias_residual_fp16",
          &ds_bias_residual<__half>,
          "DeepSpeed residual-bias add with fp16 (CUDA)");
    m.def("layer_norm", &ds_layer_norm, "DeepSpeed layer norm (CUDA)");
    m.def("_layer_norm_residual",
          &ds_layer_norm_residual,
          "DeepSpeed layer norm + residual (CUDA)");
    m.def("layer_norm_residual_store_pre_ln_res",
          &ds_layer_norm_residual_store_pre_ln_res,
          "DeepSpeed layer norm + store pre Layernorm residual (CUDA)");
    m.def("qkv_gemm_fp32", &ds_qkv_gemm<float>, "DeepSpeed qkv gemm with fp32 (CUDA)");
    m.def("qkv_gemm_fp16", &ds_qkv_gemm<__half>, "DeepSpeed qkv gemm with fp16 (CUDA)");
    m.def("qkv_gemm_int8", &ds_qkv_gemm_int8<__half>, "DeepSpeed qkv gemm with int8 (CUDA)");
    m.def("mlp_gemm_fp32", &ds_mlp_gemm<float>, "DeepSpeed mlp with fp32 (CUDA)");
    m.def("mlp_gemm_fp16", &ds_mlp_gemm<__half>, "DeepSpeed mlp with fp16 (CUDA)");
    m.def("mlp_gemm_int8", &ds_mlp_gemm_int8<__half>, "DeepSpeed mlp with int8 (CUDA)");
    m.def("vector_matmul_fp32", &ds_vector_matmul<float>, "DeepSpeed vector-MM with fp32 (CUDA)");
    m.def("vector_matmul_fp16", &ds_vector_matmul<__half>, "DeepSpeed vector-MM with fp16 (CUDA)");
    m.def("vector_matmul_int8",
          &ds_vector_matmul_int8<__half>,
          "DeepSpeed vector-MM with int8 (CUDA)");
    m.def("linear_layer_fp32", &ds_linear_layer<float>, "DeepSpeed linear_layer with fp32 (CUDA)");
    m.def("linear_layer_fp16", &ds_linear_layer<__half>, "DeepSpeed linear_layer with fp16 (CUDA)");
    m.def("linear_layer_int8",
          &ds_linear_layer_int8<__half>,
          "DeepSpeed linear_layer with int8 (CUDA)");
    m.def("fused_gemm_gelu_fp32", &fused_gemm_gelu<float>, "DeepSpeed mlp with fp32 (CUDA)");
    m.def("fused_gemm_gelu_fp16", &fused_gemm_gelu<__half>, "DeepSpeed mlp with fp16 (CUDA)");
    m.def("residual_add_bias_fp32",
          &residual_add_bias<float>,
          "DeepSpeed residual add with fp32 (CUDA)");
    m.def("residual_add_bias_fp16",
          &residual_add_bias<__half>,
          "DeepSpeed residual add with fp16 (CUDA)");
    m.def("apply_rotary_pos_emb", &apply_rotary_pos_emb, "DeepSpeed mlp with fp16 (CUDA)");
    m.def("einsum_sec_sm_ecm_fp32",
          &einsum_sec_sm_ecm<float>,
          "DeepSpeed vector-MM with fp32 (CUDA)");
    m.def("einsum_sec_sm_ecm_fp16",
          &einsum_sec_sm_ecm<__half>,
          "DeepSpeed vector-MM with fp16 (CUDA)");
    m.def("moe_res_matmul", &moe_res_matmul, "DeepSpeed moe residual matmul (CUDA)");
    m.def("add_padding_fp32", &add_padding<float>, "DeepSpeed residual add with fp32 (CUDA)");
    m.def("add_padding_fp16", &add_padding<__half>, "DeepSpeed residual add with fp16 (CUDA)");
    m.def("pad_transform_fp32",
          &padd_add_transform<float>,
          "DeepSpeed residual add with fp32 (CUDA)");
    m.def("pad_transform_fp16",
          &padd_add_transform<__half>,
          "DeepSpeed residual add with fp16 (CUDA)");
    m.def("allocate_workspace_fp32",
          &allocate_workspace<float>,
          "DeepSpeed memory allocation for GPT inference with fp32 (CUDA)");
    m.def("allocate_workspace_fp16",
          &allocate_workspace<__half>,
          "DeepSpeed memory allocation for GPT inference with fp16 (CUDA)");
    m.def("reset_cache", &reset_cache, "Reset Cache for generation tasks");
}
csrc/transformer/inference/csrc/relu.hip
deleted
100644 → 0
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include "conversion_utils.h"
#include "inference_cuda_layers.h"
#include "memory_access_utils.h"
namespace cg = cooperative_groups;
#define MAX_CAP 4
#define MAX_SEQ 2048
inline __device__ float relu(const float x) { return x < 0 ? 0 : x; }
/*
In-place relu(biasAdd(x)) for channels last
*/
template <typename T>
__global__ void fused_bias_relu(T* input, const T* bias, int total_count, int intermediate_size)
{
// Input restriction: intermediate_size % vals_per_access == 0
constexpr int granularity = 16;
constexpr int values_per_access = granularity / sizeof(T);
const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * values_per_access;
if (offset < total_count) {
T data[values_per_access];
T data_bias[values_per_access];
mem_access::load_global<granularity>(data, input + offset);
mem_access::load_global<granularity>(data_bias, bias + (offset % intermediate_size));
#pragma unroll
for (int i = 0; i < values_per_access; i++) {
float data_f = conversion::to<float>(data[i]);
float bias_f = conversion::to<float>(data_bias[i]);
data[i] = conversion::to<T>(relu(data_f + bias_f));
}
mem_access::store_global<granularity>(input + offset, data);
}
}
template <typename T>
void launch_bias_relu(T* input,
const T* bias,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
constexpr int threads = 1024;
constexpr int granularity = 16;
const int total_count = batch_size * intermediate_size;
const int elems_per_block = threads * (granularity / sizeof(T));
dim3 block_dims(threads);
dim3 grid_dims((total_count + elems_per_block - 1) / elems_per_block);
hipLaunchKernelGGL(( fused_bias_relu), dim3(grid_dims), dim3(block_dims), 0, stream,
input, bias, total_count, intermediate_size);
}
template void launch_bias_relu<float>(float*, const float*, int, int, hipStream_t);
template void launch_bias_relu<__half>(__half*, const __half*, int, int, hipStream_t);
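As a side note on the launch arithmetic above: each thread moves 16 bytes per vectorized access, so the per-thread element count and the grid size depend only on sizeof(T). Below is a small host-only sketch of that arithmetic (not part of the original file; the batch and intermediate sizes are illustrative, and the element sizes 4 and 2 stand in for float and __half respectively).

// Sketch only: reproduces the grid-size arithmetic of launch_bias_relu on the host.
#include <cstdio>

void print_launch_shape(int batch_size, int intermediate_size, int element_bytes)
{
    const int threads = 1024;                                  // same as launch_bias_relu
    const int granularity = 16;                                // bytes loaded per thread per access
    const int values_per_access = granularity / element_bytes; // 4 for float, 8 for __half
    const int total_count = batch_size * intermediate_size;
    const int elems_per_block = threads * values_per_access;
    const int blocks = (total_count + elems_per_block - 1) / elems_per_block;
    printf("elem=%dB values_per_access=%d blocks=%d\n", element_bytes, values_per_access, blocks);
}

int main()
{
    print_launch_shape(16, 4096, 4);  // float:  4 values per thread -> 16 blocks
    print_launch_shape(16, 4096, 2);  // __half: 8 values per thread -> 8 blocks
    return 0;
}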
csrc/transformer/inference/csrc/softmax.hip
deleted
100644 → 0
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include <limits>
#include "inference_cuda_layers.h"
#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <stdexcept>
#define ATTN_THREADS 256
#define MAX_REG_SIZE 8
#define minus_infinity -10000.0
void CheckCudaErrorAux(const char* file, unsigned line)
{
hipError_t err = hipGetLastError();
if (err == hipSuccess) return;
std::cerr << hipGetErrorString(err) << "(" << err << ") at " << file << ":" << line
<< std::endl;
throw std::runtime_error("CUDA ERROR!!!\n");
}
#define CUDA_CHECK_ERROR() CheckCudaErrorAux(__FILE__, __LINE__)
namespace cg = cooperative_groups;
__global__ void attn_softmax_v2(__half* vals,
__half* mask,
__half* alibi,
float layer_scale,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int total_count,
int heads,
int sequence_length,
int num_seq,
int head_offset,
int mask_stride,
int mp_size,
int iterations,
int reduceWidth)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
float2 low_data[MAX_REG_SIZE];
float2 high_data[MAX_REG_SIZE];
const __half zero_h = __float2half(0.f);
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int reduce_blocks = reduceWidth >> 5;
int seq_lane = threadIdx.x % reduceWidth;
__shared__ float partialSum[MAX_WARP_NUM];
int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks);
int batch_idx = iter_offset / (num_seq * heads);
int alibi_offset = batch_idx * heads * mp_size + head_offset;
int mask_offset = batch_idx * mask_stride + (iter_offset % mask_stride);
if (iter_offset < total_count) {
vals += (iter_offset * sequence_length);
alibi_offset = (alibi_offset + ((iter_offset / num_seq) % heads)) * sequence_length;
mask_offset = mask_offset * sequence_length;
int seq_id = iter_offset % num_seq;
int seq_id4 = seq_id >> 2;
int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
? (real_seq_id >> 2) - (window_size >> 2)
: 0;
int window_stride =
(local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1;
float max_val = minus_infinity;
// if (lane == 0) printf("%d, %d: %d \n", wid, blockIdx.x, mask_offset);
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 &&
data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
low_data[i].x = data_id > window_stride
? __half2float(vals[data_id]) * layer_scale
: minus_infinity;
low_data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride)
? __half2float(vals[data_id + 1]) * layer_scale
: minus_infinity;
high_data[i].x = ((!triangular || ((data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride)
? __half2float(vals[data_id + 2]) * layer_scale
: minus_infinity;
high_data[i].y = ((!triangular || ((data_id + 3) <= seq_id)) &&
(data_id + 3) > window_stride)
? __half2float(vals[data_id + 3]) * layer_scale
: minus_infinity;
if (alibi) {
low_data[i].x = low_data[i].x + __half2float(alibi[data_id + alibi_offset]);
low_data[i].y =
low_data[i].y + __half2float(alibi[data_id + alibi_offset + 1]);
high_data[i].x =
high_data[i].x + __half2float(alibi[data_id + alibi_offset + 2]);
high_data[i].y =
high_data[i].y + __half2float(alibi[data_id + alibi_offset + 3]);
}
if (mask) {
low_data[i].x += __half2float(mask[data_id + mask_offset]);
low_data[i].y += __half2float(mask[data_id + mask_offset + 1]);
high_data[i].x += __half2float(mask[data_id + mask_offset + 2]);
high_data[i].y += __half2float(mask[data_id + mask_offset + 3]);
}
} else {
low_data[i].x = data_id > window_stride
? __half2float(vals[data_id]) * layer_scale
: minus_infinity;
low_data[i].y = (((!triangular || (data_id + 1) <= seq_id) &&
(data_id + 1) > window_stride) &&
(data_id + 1) < sequence_length)
? __half2float(vals[data_id + 1]) * layer_scale
: minus_infinity;
high_data[i].x = (((!triangular || (data_id + 2) <= seq_id) &&
(data_id + 2) > window_stride) &&
(data_id + 2) < sequence_length)
? __half2float(vals[data_id + 2]) * layer_scale
: minus_infinity;
if (alibi) {
low_data[i].x = low_data[i].x + __half2float(alibi[data_id + alibi_offset]);
if ((data_id + 1) < sequence_length)
low_data[i].y =
low_data[i].y + __half2float(alibi[data_id + alibi_offset + 1]);
if ((data_id + 2) < sequence_length)
high_data[i].x =
high_data[i].x + __half2float(alibi[data_id + alibi_offset + 2]);
}
high_data[i].y = minus_infinity;
if (mask) {
low_data[i].x += __half2float(mask[data_id + mask_offset]);
if ((data_id + 1) < sequence_length)
low_data[i].y += __half2float(mask[data_id + mask_offset + 1]);
if ((data_id + 2) < sequence_length)
high_data[i].x += __half2float(mask[data_id + mask_offset + 2]);
}
}
// if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id);
max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
max_val = (low_data[i].y > max_val ? low_data[i].y : max_val);
max_val = (high_data[i].x > max_val ? high_data[i].x : max_val);
max_val = (high_data[i].y > max_val ? high_data[i].y : max_val);
} else {
low_data[i].x = minus_infinity;
low_data[i].y = minus_infinity;
high_data[i].x = minus_infinity;
high_data[i].y = minus_infinity;
}
}
for (int i = 1; i < WARP_SIZE; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
low_data[i].x = __expf(low_data[i].x - max_val);
low_data[i].y = __expf(low_data[i].y - max_val);
high_data[i].x = __expf(high_data[i].x - max_val);
high_data[i].y = __expf(high_data[i].y - max_val);
sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i);
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if (data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
vals[data_id] = __float2half(low_data[i].x / sum);
vals[data_id + 1] = __float2half(low_data[i].y / sum);
vals[data_id + 2] = __float2half(high_data[i].x / sum);
vals[data_id + 3] = __float2half(high_data[i].y / sum);
} else {
vals[data_id] = __float2half(low_data[i].x / sum);
if ((data_id + 1) < sequence_length)
vals[data_id + 1] = __float2half(low_data[i].y / sum);
if ((data_id + 2) < sequence_length)
vals[data_id + 2] = __float2half(high_data[i].x / sum);
}
}
}
}
}
__global__ void attn_softmax_v2(float* vals,
float* attn_mask,
float* alibi,
float layer_scale,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int total_count,
int heads,
int sequence_length,
int num_seq,
int head_offset,
int mask_stride,
int mp_size,
int iterations,
int reduceWidth)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
float4 data[MAX_REG_SIZE];
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int reduce_blocks = reduceWidth >> 5;
int seq_lane = threadIdx.x % reduceWidth;
__shared__ float partialSum[MAX_WARP_NUM];
int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks);
if (iter_offset < total_count) {
vals += (iter_offset * sequence_length);
int batch_idx = iter_offset / (num_seq * heads);
int alibi_offset = batch_idx * heads * mp_size + head_offset;
int mask_offset = batch_idx * mask_stride + (iter_offset % mask_stride);
mask_offset = mask_offset * sequence_length;
int seq_id = iter_offset % num_seq;
int seq_id4 = seq_id >> 2;
int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
? (real_seq_id >> 2) - (window_size >> 2)
: 0;
int window_stride =
(local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1;
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 &&
data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
data[i].x = (data_id > window_stride ? vals[data_id] : minus_infinity);
data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride)
? vals[data_id + 1]
: minus_infinity;
data[i].z = ((!triangular || ((data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride)
? vals[data_id + 2]
: minus_infinity;
data[i].w = ((!triangular || ((data_id + 3) <= seq_id)) &&
(data_id + 3) > window_stride)
? vals[data_id + 3]
: minus_infinity;
if (attn_mask) {
data[i].x += attn_mask[data_id + mask_offset];
data[i].y += attn_mask[data_id + mask_offset + 1];
data[i].z += attn_mask[data_id + mask_offset + 2];
data[i].w += attn_mask[data_id + mask_offset + 3];
}
} else {
data[i].x = data_id > window_stride ? vals[data_id] : minus_infinity;
data[i].y = (((!triangular || (data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride && (data_id + 1) < sequence_length)
? (vals[data_id + 1])
: minus_infinity;
data[i].z = (((!triangular || (data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride && (data_id + 2) < sequence_length)
? (vals[data_id + 2])
: minus_infinity;
data[i].w = minus_infinity;
if (attn_mask) {
data[i].x += attn_mask[data_id + mask_offset];
if ((data_id + 1) < sequence_length)
data[i].y += attn_mask[data_id + mask_offset + 1];
if ((data_id + 2) < sequence_length)
data[i].z += attn_mask[data_id + mask_offset + 2];
}
}
max_val = (data[i].x > max_val ? data[i].x : max_val);
max_val = (data[i].y > max_val ? data[i].y : max_val);
max_val = (data[i].z > max_val ? data[i].z : max_val);
max_val = (data[i].w > max_val ? data[i].w : max_val);
} else {
data[i].x = minus_infinity;
data[i].y = minus_infinity;
data[i].z = minus_infinity;
data[i].w = minus_infinity;
}
}
for (int i = 1; i < WARP_SIZE; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
data[i].x = __expf(data[i].x - max_val);
data[i].y = __expf(data[i].y - max_val);
data[i].z = __expf(data[i].z - max_val);
data[i].w = __expf(data[i].w - max_val);
sum += (data[i].x + data[i].y + data[i].z + data[i].w);
}
for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i);
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if (data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
vals[data_id] = data[i].x / sum;
vals[data_id + 1] = data[i].y / sum;
vals[data_id + 2] = data[i].z / sum;
vals[data_id + 3] = data[i].w / sum;
} else {
vals[data_id] = data[i].x / sum;
if ((data_id + 1) < sequence_length) vals[data_id + 1] = data[i].y / sum;
if ((data_id + 2) < sequence_length) vals[data_id + 2] = data[i].z / sum;
}
}
}
}
}
template <typename T>
void launch_attn_softmax_v2(T* vals,
T* mask,
T* alibi,
float layer_scale,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
int head_offset,
int mask_stride,
int mp_size,
hipStream_t stream)
{
int total_count = batch_size * heads * num_seq;
int warp_num = ATTN_THREADS / WARP_SIZE;
int reduce_width = ((sequence_length - 1) / ATTN_THREADS + 1);
reduce_width = (int)pow(2.0, floor(log2((float)(reduce_width)))) * WARP_SIZE;
dim3 grid_dim((total_count - 1) / (ATTN_THREADS / reduce_width) + 1);
dim3 block_dim(ATTN_THREADS);
const int iterations = (sequence_length - 1) / (reduce_width << 2) + 1;
if (sequence_length <= 32768)
hipLaunchKernelGGL(( attn_softmax_v2), dim3(grid_dim), dim3(block_dim), 0, stream, vals,
mask,
alibi,
layer_scale,
triangular,
recompute,
local_attention,
window_size,
total_count,
heads,
sequence_length,
num_seq,
head_offset,
mask_stride,
mp_size,
iterations,
reduce_width);
else
throw std::runtime_error("Unsupport Seq_Length!");
}
template void launch_attn_softmax_v2(float* vals,
float* mask,
float* alibi,
float layer_scale,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
int head_offset,
int mask_stride,
int mp_size,
hipStream_t stream);
template void launch_attn_softmax_v2(__half* vals,
__half* mask,
__half* alibi,
float layer_scale,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
int head_offset,
int mask_stride,
int mp_size,
hipStream_t stream);
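For reference, a host-only sketch of how launch_attn_softmax_v2 maps sequence_length to its block layout (not part of the original file). Assumptions: the warp/wavefront width is passed in explicitly (commonly 64 on AMD GPUs and 32 on NVIDIA), and the sample batch/head/sequence values are illustrative only.

// Sketch only: mirrors the reduce_width / grid / iterations arithmetic above.
#include <cmath>
#include <cstdio>

void softmax_launch_shape(int sequence_length, int total_count, int warp_size)
{
    const int attn_threads = 256;  // ATTN_THREADS
    int reduce_width = (sequence_length - 1) / attn_threads + 1;
    reduce_width = (int)pow(2.0, floor(log2((float)reduce_width))) * warp_size;
    const int blocks = (total_count - 1) / (attn_threads / reduce_width) + 1;
    const int iterations = (sequence_length - 1) / (reduce_width << 2) + 1;
    printf("seq=%d reduce_width=%d blocks=%d iterations=%d\n",
           sequence_length, reduce_width, blocks, iterations);
}

int main()
{
    softmax_launch_shape(128, 1 * 16 * 128, 64);    // e.g. batch 1, 16 heads, 128 query rows
    softmax_launch_shape(1024, 1 * 16 * 1024, 64);  // longer prompt, same batch and head count
    return 0;
}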
csrc/transformer/inference/csrc/transform.hip
deleted
100644 → 0
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif
#include "inference_cuda_layers.h"
namespace cg = cooperative_groups;
// Bias add
__global__ void bias_add_transform_0213(float* output,
float* k_cache,
float* v_cache,
const float* vals,
const float* bias,
int hidden_dim,
int seq_length,
unsigned seq_offset,
int heads,
int rotary_dim,
bool rotate_half,
bool rotate_every_two,
int head_ext,
int max_out_tokens)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = blockIdx.z / head_ext; // Hidden count
int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
int d2_out_stride = d2_stride * (cnt == 0 ? seq_length : max_out_tokens);
int d0_out_stride = hidden_dim * (cnt == 0 ? seq_length : max_out_tokens);
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
float4* output_vec =
reinterpret_cast<float4*>(cnt == 0 ? output : (cnt == 1 ? k_cache : v_cache));
vals_vec += (d0 * d0_stride * (gridDim.z / head_ext));
vals_vec += (d1 * d1_stride * (gridDim.z / head_ext));
vals_vec += (cnt * d1_stride);
vals_vec += (d2 * d2_stride);
output_vec += (d1 * d2_stride);
output_vec += (d0 * d0_out_stride);
output_vec += (d2 * d2_out_stride);
unsigned seq_id = d1 + seq_offset;
float4 inputs = vals_vec[d3];
int lane = d3 & 0x1f;
if (cnt < 2 && rotary_dim > 0 && d3 < rotary_dim) {
float4 q = vals_vec[d3];
float2* q_f = reinterpret_cast<float2*>(&q);
if (rotate_every_two) {
#pragma unroll
for (int o = 0; o < 2; o++) {
float inv_freq = (float)(((d3 << 1) + o) * 2) / (float)(rotary_dim << 2);
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q_x = q_f[o].x;  // save the original x so both components use the pre-rotation pair
q_f[o].x = (-1.0 * q_f[o].y * sinf(inv_freq) + q_x * cosf(inv_freq));
q_f[o].y = (q_x * sinf(inv_freq) + q_f[o].y * cosf(inv_freq));
}
}
output_vec[d3] = q;
} else
output_vec[d3] = inputs;
}
#define ATTN_H 3
#define MAX_SEQ_LINE 10
__global__ void bias_add_transform_0213(__half* output, // q
__half* k_cache,
__half* v_cache,
const __half* vals, // qkv
const __half* bias,
int hidden_dim,
int seq_length,
unsigned seq_offset,
int all_tokens,
int heads,
int rotary_dim,
bool rotate_half,
bool rotate_every_two,
int head_ext,
int max_out_tokens)
{
unsigned half_dim = (rotary_dim << 3) >> 1;
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = blockIdx.z / head_ext; // Hidden count
int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
int d2_out_stride = d2_stride * (cnt == 0 ? seq_length : max_out_tokens);
int d0_out_stride = hidden_dim * (cnt == 0 ? seq_length : max_out_tokens);
float4 vals_arr;
float4 output_arr;
__half2* vals_half = reinterpret_cast<__half2*>(&vals_arr);
__half2* output_half = reinterpret_cast<__half2*>(&output_arr);
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
float4* output_vec =
reinterpret_cast<float4*>(cnt == 0 ? output : (cnt == 1 ? k_cache : v_cache));
vals_vec += (d0 * d0_stride * (gridDim.z / head_ext));
vals_vec += (d1 * d1_stride * (gridDim.z / head_ext));
vals_vec += (cnt * d1_stride);
vals_vec += (d2 * d2_stride);
output_vec += (d1 * d2_stride);
output_vec += (d0 * d0_out_stride);
output_vec += (d2 * d2_out_stride);
unsigned seq_id = d1 + seq_offset;
int lane = d3 & 0x1f;
if (cnt < 2 && rotary_dim > 0 && d3 < rotary_dim) {
float4 q = vals_vec[d3];
__half2* q_h = reinterpret_cast<__half2*>(&q);
if (rotate_every_two) {
#pragma unroll
for (int o = 0; o < 4; o++) {
float inv_freq = (float)(((d3 << 2) + o) * 2) / (float)(rotary_dim << 3);
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q_data[2];
q_data[0] = (float)q_h[o].x;
q_data[1] = (float)q_h[o].y;
q_h[o].x = (__half)(-1.0 * q_data[1] * sinf(inv_freq) + q_data[0] * cosf(inv_freq));
q_h[o].y = (__half)(q_data[0] * sinf(inv_freq) + q_data[1] * cosf(inv_freq));
}
}
output_vec[d3] = q;
} else
output_vec[d3] = vals_vec[d3];
}
// [B S C*H] - > C * [B A S N]
template <>
void launch_bias_add_transform_0213<float>(float* output,
float* k_cache,
float* v_cache,
const float* vals,
const float* bias,
int batch_size,
int seq_length,
unsigned seq_offset,
int all_tokens,
int hidden_dim,
int heads,
int rotary_dim,
bool rotate_half,
bool rotate_every_two,
hipStream_t stream,
int trans_count,
int max_out_tokens)
{
hidden_dim >>= 2;
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));
hipLaunchKernelGGL(( bias_add_transform_0213), dim3(grid_dim), dim3(block_dim), 0, stream, output,
k_cache,
v_cache,
vals,
bias,
hidden_dim,
seq_length,
seq_offset,
heads,
rotary_dim >> 2,
rotate_half,
rotate_every_two,
head_ext,
max_out_tokens);
}
template <typename T>
void launch_bias_add_transform_0213(T* outputs,
T* vals,
T* vals1,
const T* vals2,
const T* bias,
int batch_size,
int seq_length,
unsigned seq_offset,
int seq_length1,
int hidden_dim,
int heads,
int rotary_dim,
bool rotate_half,
bool rotate_every_two,
hipStream_t stream,
int trans_count,
int max_out_tokens);
template <>
void launch_bias_add_transform_0213<__half>(__half* output,
__half* k_cache,
__half* v_cache,
const __half* vals,
const __half* bias,
int batch_size,
int seq_length,
unsigned seq_offset,
int all_tokens,
int hidden_dim,
int heads,
int rotary_dim,
bool rotate_half,
bool rotate_every_two,
hipStream_t stream,
int trans_count,
int max_out_tokens)
{
hidden_dim >>= 3;
int head_ext = 1; // (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));
hipLaunchKernelGGL(( bias_add_transform_0213), dim3(grid_dim), dim3(block_dim), 0, stream, output,
k_cache,
v_cache,
vals,
bias,
hidden_dim,
seq_length,
seq_offset,
all_tokens,
heads,
rotary_dim >> 3,
rotate_half,
rotate_every_two,
head_ext,
max_out_tokens);
}
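Both bias_add_transform_0213 kernels apply the rotate-every-two rotary embedding inline: each (x, y) pair of a head vector is rotated by an angle that depends on the token position and the pair index. Below is a minimal host-side sketch of that per-pair rotation (a hypothetical helper, not part of the file; rotary_dim here counts elements). Note that it saves the original x before updating the pair, so both outputs are computed from the pre-rotation values.

// Sketch only: the standard rotate-every-two rotation the kernels above implement inline.
#include <cmath>
#include <cstdio>

void rotate_every_two_pair(float& x, float& y, int pair_index, int rotary_dim, int position)
{
    // theta = position / 10000^(2 * pair_index / rotary_dim)
    const float inv_freq = 1.0f / powf(10000.0f, (float)(2 * pair_index) / (float)rotary_dim);
    const float theta = inv_freq * (float)position;
    const float x0 = x;  // keep the original value; the pair is rotated jointly
    x = -y * sinf(theta) + x0 * cosf(theta);
    y = x0 * sinf(theta) + y * cosf(theta);
}

int main()
{
    float x = 1.0f, y = 0.0f;
    rotate_every_two_pair(x, y, /*pair_index=*/0, /*rotary_dim=*/64, /*position=*/3);
    printf("rotated pair: (%f, %f)\n", x, y);  // equals (cos(3), sin(3)) for pair 0
    return 0;
}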
// Bias add
__global__ void pad_add_transform_0213(float* output,
const float* vals,
int hidden_dim,
int seq_length,
int padded_seq_len,
int heads,
int padded_head_size)
{
}
__global__ void pad_add_transform_0213(__half* output,
const __half* vals,
int hidden_dim,
int seq_length,
int padded_seq_len,
int heads,
int padded_head_size)
{
float4 ZERO;
const __half2 zero_h = __float2half2_rn(0.f);
__half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO);
#pragma unroll
for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h;
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y * blockDim.z + threadIdx.z; // Sequence ID (0-127)
int d2 = threadIdx.y; // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
int d2_out_stride = padded_head_size * padded_seq_len;
int d0_out_stride = heads * d2_out_stride;
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
float4* output_vec = reinterpret_cast<float4*>(output);
vals_vec += (d0 * d0_stride);
vals_vec += (d1 * d1_stride);
vals_vec += (d2 * d2_stride);
output_vec += (d1 * padded_head_size);
output_vec += (d0 * d0_out_stride);
output_vec += (d2 * d2_out_stride);
if (d3 < d2_stride && d1 < seq_length)
output_vec[d3] = vals_vec[d3];
else
output_vec[d3] = ZERO;
}
template <typename T>
void launch_pad_add_transform_0213(T* output,
const T* vals,
int batch_size,
int hidden_dim,
int seq_length,
int padded_seq_len,
int heads,
int padded_head_size,
hipStream_t stream);
// [B S C*H] - > C * [B A S N]
template <>
void launch_pad_add_transform_0213<float>(float* output,
const float* vals,
int batch_size,
int hidden_dim,
int seq_length,
int padded_seq_len,
int heads,
int padded_head_size,
hipStream_t stream)
{
}
template <>
void launch_pad_add_transform_0213<__half>(__half* output,
const __half* vals,
int batch_size,
int hidden_dim,
int seq_length,
int padded_seq_len,
int heads,
int padded_head_size,
hipStream_t stream)
{
hidden_dim >>= 3;
dim3 block_dim((padded_head_size >> 3), heads, 2);
dim3 grid_dim(batch_size, padded_seq_len / 2);
hipLaunchKernelGGL(( pad_add_transform_0213), dim3(grid_dim), dim3(block_dim), 0, stream,
output, vals, hidden_dim, seq_length, padded_seq_len, heads, padded_head_size >> 3);
}
// Bias add
template <typename T>
__global__ void bias_add_transform_0213(T* output,
const T* vals,
const T* bias,
int hidden_dim,
int seq_length,
int heads,
int head_ext);
template <>
__global__ void bias_add_transform_0213<float>(float* output,
const float* vals,
const float* bias,
int hidden_dim,
int seq_length,
int heads,
int head_ext)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = blockIdx.z / head_ext; // Hidden count
int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
const float4* bias_vec = reinterpret_cast<const float4*>(bias);
float4* output_vec = reinterpret_cast<float4*>(output);
float4 inputs = vals_vec[d0 * d0_stride * (gridDim.z / head_ext) + cnt * d1_stride +
d1 * d1_stride * (gridDim.z / head_ext) + d2 * d2_stride + d3];
float4 biases = bias_vec[cnt * d1_stride + d2 * d2_stride + d3];
float4 outputs;
outputs.x = inputs.x + biases.x;
outputs.y = inputs.y + biases.y;
outputs.z = inputs.z + biases.z;
outputs.w = inputs.w + biases.w;
output_vec[cnt * d0_out_stride * gridDim.x + d0 * d0_out_stride + d1 * d1_out_stride +
d2 * d2_out_stride + d3] = outputs;
}
template <>
__global__ void bias_add_transform_0213<__half>(__half* output,
const __half* vals,
const __half* bias,
int hidden_dim,
int seq_length,
int heads,
int head_ext)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = blockIdx.z / head_ext; // Hidden count
int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
float4 vals_arr;
float4 bias_arr;
float4 output_arr;
__half2* vals_half = reinterpret_cast<__half2*>(&vals_arr);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_arr);
__half2* output_half = reinterpret_cast<__half2*>(&output_arr);
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
const float4* bias_vec = reinterpret_cast<const float4*>(bias);
float4* output_vec = reinterpret_cast<float4*>(output);
vals_vec += (d0 * d0_stride * (gridDim.z / head_ext));
vals_vec += (d1 * d1_stride * (gridDim.z / head_ext));
vals_vec += (cnt * d1_stride);
vals_vec += (d2 * d2_stride);
bias_vec += (cnt * d1_stride);
bias_vec += (d2 * d2_stride);
output_vec += (cnt * d0_stride * gridDim.x);
output_vec += (d1 * d2_stride);
output_vec += (d0 * d0_stride);
output_vec += (d2 * d2_out_stride);
bias_arr = bias_vec[d3];
vals_arr = vals_vec[d3];
output_half[0] = vals_half[0] + bias_half[0];
output_half[1] = vals_half[1] + bias_half[1];
output_half[2] = vals_half[2] + bias_half[2];
output_half[3] = vals_half[3] + bias_half[3];
output_vec[d3] = output_arr;
}
__global__ void bias_add_transform_0213_v2(__half* output,
const __half* vals,
const __half* bias,
int hidden_dim,
int seq_length,
int heads)
{
__shared__ float4 in_data[3072];
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int iteration_stride = d1_stride * blockDim.z; // Hidden * 3 / 8
int batch_stride = d0_stride * blockDim.z; // Hidden * S * 3 / 8
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = threadIdx.z; // blockIdx.z; // Hidden count
int d2 = threadIdx.y; // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
float4 vals_arr[1];
float4 bias_arr[1];
float4 output_arr[1];
__half2* vals_half = reinterpret_cast<__half2*>(vals_arr);
__half2* bias_half = reinterpret_cast<__half2*>(bias_arr);
__half2* output_half = reinterpret_cast<__half2*>(output_arr);
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
const float4* bias_vec = reinterpret_cast<const float4*>(bias);
float4* output_vec = reinterpret_cast<float4*>(output);
int iter_index = cnt * d1_stride + d2 * d2_stride + d3;
int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1);
bias_arr[0] = bias_vec[iter_index];
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_id = iter * iteration_stride + iter_index;
vals_arr[0] = vals_vec[input_offset + iter_id];
output_half[0] = vals_half[0] + bias_half[0];
output_half[1] = vals_half[1] + bias_half[1];
output_half[2] = vals_half[2] + bias_half[2];
output_half[3] = vals_half[3] + bias_half[3];
in_data[iter_id] = output_arr[0];
}
__syncthreads();
iteration_stride = blockDim.z * (blockDim.y >> 1);
int matrix_stride = (d0_out_stride * gridDim.x);
int head_count = (d2 >> 1) + cnt * (blockDim.y >> 1);
int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + (d2 % 2) * d2_stride;
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_row = (iter * iteration_stride) + head_count;
int iter_offset =
(iter_row % blockDim.y) * d2_out_stride + (iter_row / blockDim.y) * matrix_stride;
output_vec[out_index + iter_offset] =
in_data[iter_row * d2_stride + d3 + (d2 % 2) * (d1_stride * blockDim.z)];
}
}
template <typename T>
__global__ void transform4d_0213(T* out,
const T* in,
int heads,
int seq_length,
int hidden_dim,
int head_ext);
template <>
__global__ void transform4d_0213<float>(float* out,
const float* in,
int heads,
int seq_length,
int hidden_dim,
int head_ext)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = d0_stride / heads;
int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = hidden_dim;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y / ((seq_length - 1) / blockDim.y + 1); // Head
int d2 = (threadIdx.y + blockDim.y * blockIdx.y) % seq_length;
int cnt = blockIdx.z;
int d3 = threadIdx.x; // Values (groups of 8)
if (d2 < seq_length) {
const float4* in_vec = reinterpret_cast<const float4*>(in);
float4* out_vec = reinterpret_cast<float4*>(out);
float4 vals_vec = in_vec[cnt * d0_stride * gridDim.x + d0 * d0_stride + d1 * d1_stride +
d2 * d2_stride + d3];
out_vec[d0 * d0_out_stride * gridDim.z + cnt * d2_out_stride + d1 * d1_out_stride +
d2 * d2_out_stride * gridDim.z + d3] = vals_vec;
}
}
template <>
__global__ void transform4d_0213<__half>(__half* out,
const __half* in,
int heads,
int seq_length,
int hidden_dim,
int head_ext)
{
int d0_stride = hidden_dim * (seq_length / head_ext);
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0 = blockIdx.x; // Batch
int d1 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head
int d2 = blockIdx.z / head_ext; // Sequence
int cnt = blockIdx.y; // Hidden count
int d3 = threadIdx.x; // Values (groups of 8)
const float4* in_vec = reinterpret_cast<const float4*>(in);
float4* out_vec = reinterpret_cast<float4*>(out);
in_vec += (cnt * d0_stride * gridDim.x);
in_vec += (d0 * d0_stride);
in_vec += (d2 * d2_stride);
in_vec += (d1 * d2_stride * seq_length);
out_vec += (cnt * d1_stride);
out_vec += (d1 * d2_stride);
out_vec += (d0 * d0_stride * gridDim.y);
out_vec += (d2 * d1_stride * gridDim.y);
out_vec[d3] = in_vec[d3];
}
__global__ void transform4d_0213_v2(__half* out,
const __half* in,
int heads,
int seq_length,
int hidden_dim)
{
__shared__ float4 in_data[3072];
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0 = blockIdx.x; // Batch
int d1 = threadIdx.y; // Head
int d2 = blockIdx.y; // Sequence
int cnt = threadIdx.z; // Hidden count
int d3 = threadIdx.x; // Values (groups of 8)
const float4* in_vec = reinterpret_cast<const float4*>(in);
float4* out_vec = reinterpret_cast<float4*>(out);
int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + (d1 % 2) * d2_stride;
int head_count = (d1 >> 1) + cnt * (blockDim.y >> 1);
int iteration_stride = blockDim.z * (blockDim.y >> 1);
int matrix_stride = (d0_stride * gridDim.x);
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_row = iter * iteration_stride + head_count;
int iter_offset = (iter_row % blockDim.y) * d2_stride;
in_data[d3 + iter_offset + (iter_row / blockDim.y + (d1 % 2) * blockDim.z) * d1_stride] =
in_vec[input_offset + iter_offset * seq_length +
(iter_row / blockDim.y) * matrix_stride];
}
__syncthreads();
iteration_stride = d1_stride * blockDim.z;
int iter_index = cnt * d1_stride + d1 * d2_stride + d3;
int output_offset = d0 * d0_stride * blockDim.z + d2 * (iteration_stride << 1);
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_id = iter * iteration_stride + iter_index;
out_vec[output_offset + iter_id] = in_data[iter_id];
}
}
// 3 * [B A S N] - > [B S C*H]
template <>
void launch_transform4d_0213<float>(float* out,
const float* in,
int batch_size,
int heads,
int seq_length,
int hidden_dim,
hipStream_t stream,
int trans_count)
{
hidden_dim >>= 2;
dim3 grid_dims(batch_size, heads * ((seq_length - 1) / 8 + 1), trans_count);
dim3 block_dims(hidden_dim / heads, 8);
hipLaunchKernelGGL(( transform4d_0213<float>)
, dim3(grid_dims), dim3(block_dims), 0, stream, out, in, heads, seq_length, hidden_dim, 1);
}
template <>
void launch_transform4d_0213<__half>(__half* out,
const __half* in,
int batch_size,
int heads,
int seq_length,
int hidden_dim,
hipStream_t stream,
int trans_count)
{
hidden_dim >>= 3;
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 grid_dims(batch_size, trans_count, (seq_length * head_ext));
dim3 block_dims(hidden_dim / heads, (heads / head_ext));
hipLaunchKernelGGL(( transform4d_0213<__half>)
, dim3(grid_dims), dim3(block_dims), 0, stream, out, in, heads, seq_length, hidden_dim, head_ext);
}