OpenDAS / deepspeed / Commits / 4acf0e01

Commit 4acf0e01 authored Apr 26, 2023 by aiss

    delete hip file

parent 7dd68788
Showing 3 changed files with 0 additions and 1749 deletions:

    csrc/transformer_bak/softmax_kernels.hip    +0  -597
    csrc/transformer_bak/transform_kernels.cu   +0  -575
    csrc/transformer_bak/transform_kernels.hip  +0  -577
csrc/transformer_bak/softmax_kernels.hip    deleted    100644 → 0
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include <math.h>
#include "custom_hip_layers.h"
#include "general_kernels_hip.h"
namespace cg = cooperative_groups;
dim3 get_attn_softmax_grid(int batch_size, int heads, int sequence_length, int threads)
{
int seq_length4 = sequence_length / 4;
int block_compute_size =
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1);
// Note that the Y and Z dimensions are limited to 65535, while X is basically unlimited:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications
// The batch size is typically relatively small, while the sequence length could potentially be
// arbitrarily large. We therefore place the batch size second to avoid hitting the Y limit.
unsigned x = heads * sequence_length / block_compute_size;
unsigned y = batch_size;
return {x, y};
}
// Fused attention + softmax
template <int tbSize, int blockStride, int tbSeq>
__global__ void attn_softmax(float* vals,
const float* attn_mask,
int heads,
int seq_length,
int iterations)
{
__shared__ float partialSum[MAX_WARP_NUM];
int warp_num = blockDim.x >> WARP_SIZE_BITS;
int iteration_stride = blockDim.x;
int block_width = blockStride * seq_length;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);
int batch = blockIdx.y;
int row = blockIdx.x;
int max_threads_in_sequence = ::max(seq_length, tbSeq);
int seq_lane = threadIdx.x % max_threads_in_sequence;
int data_offset = batch * (gridDim.x * block_width) + row * block_width +
(threadIdx.x / max_threads_in_sequence) * seq_length;
int mask_offset = batch * seq_length;
int wid = threadIdx.x >> WARP_SIZE_BITS;
int lane = threadIdx.x & 0x1f;
float4* val_cast = reinterpret_cast<float4*>(vals);
const float4* attn_mask_cast = reinterpret_cast<const float4*>(attn_mask);
float4 data[MAX_THREAD_ITERATIONS];
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
float4 mask = attn_mask_cast[mask_offset + data_id];
data[i] = val_cast[data_offset + data_id];
data[i].x += mask.x;
data[i].y += mask.y;
data[i].z += mask.z;
data[i].w += mask.w;
max_val = (data[i].x > max_val ? data[i].x : max_val);
max_val = (data[i].y > max_val ? data[i].y : max_val);
max_val = (data[i].z > max_val ? data[i].z : max_val);
max_val = (data[i].w > max_val ? data[i].w : max_val);
} else {
data[i].x = minus_infinity;
data[i].y = minus_infinity;
data[i].z = minus_infinity;
data[i].w = minus_infinity;
}
}
for (int i = 1; i < tbSize; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / tbSize);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
data[i].x = __expf(data[i].x - max_val);
data[i].y = __expf(data[i].y - max_val);
data[i].z = __expf(data[i].z - max_val);
data[i].w = __expf(data[i].w - max_val);
sum += (data[i].x + data[i].y + data[i].z + data[i].w);
}
for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); }
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / tbSize);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
data[i].x /= sum;
data[i].y /= sum;
data[i].z /= sum;
data[i].w /= sum;
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) val_cast[data_offset + data_id] = data[i];
}
}
template <int tbSize, int blockStride, int tbSeq>
__global__ void attn_softmax(__half* vals,
const __half* attn_mask,
int heads,
int seq_length,
int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
__shared__ float partialSum[MAX_WARP_NUM];
int warp_num = blockDim.x >> WARP_SIZE_BITS;
int iteration_stride = blockDim.x;
int block_width = blockStride * seq_length;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);
int batch = blockIdx.y;
int row = blockIdx.x;
int max_threads_in_sequence = ::max(seq_length, tbSeq);
int seq_lane = threadIdx.x % max_threads_in_sequence;
int data_offset = batch * (gridDim.x * block_width) + row * block_width +
(threadIdx.x / max_threads_in_sequence) * seq_length;
int mask_offset = batch * seq_length;
int wid = threadIdx.x >> WARP_SIZE_BITS;
int lane = threadIdx.x & 0x1f;
float2* val_cast = reinterpret_cast<float2*>(vals);
const float2* attn_mask_cast = reinterpret_cast<const float2*>(attn_mask);
val_cast += data_offset;
attn_mask_cast += mask_offset;
float2 low_data[MAX_THREAD_ITERATIONS];
float2 high_data[MAX_THREAD_ITERATIONS];
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
float2 data = val_cast[data_id];
float2 mask = attn_mask_cast[data_id];
__half2* data_arr = reinterpret_cast<__half2*>(&data);
__half2* mask_arr = reinterpret_cast<__half2*>(&mask);
low_data[i] = __half22float2(data_arr[0]);
high_data[i] = __half22float2(data_arr[1]);
float2 low_mask = __half22float2(mask_arr[0]);
float2 high_mask = __half22float2(mask_arr[1]);
low_data[i].x += low_mask.x;
low_data[i].y += low_mask.y;
high_data[i].x += high_mask.x;
high_data[i].y += high_mask.y;
max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
max_val = (low_data[i].y > max_val ? low_data[i].y : max_val);
max_val = (high_data[i].x > max_val ? high_data[i].x : max_val);
max_val = (high_data[i].y > max_val ? high_data[i].y : max_val);
}
}
for (int i = 1; i < tbSize; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / tbSize);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
low_data[i].x = __expf(low_data[i].x - max_val);
low_data[i].y = __expf(low_data[i].y - max_val);
high_data[i].x = __expf(high_data[i].x - max_val);
high_data[i].y = __expf(high_data[i].y - max_val);
sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y);
}
}
for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); }
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / tbSize);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
low_data[i].x /= sum;
low_data[i].y /= sum;
high_data[i].x /= sum;
high_data[i].y /= sum;
result_h[0] = __float22half2_rn(low_data[i]);
result_h[1] = __float22half2_rn(high_data[i]);
val_cast[data_id] = result_f;
}
}
#endif
}
template <typename T>
void launch_attn_softmax(T*, const T*, int, int, int, hipStream_t);
template <>
void launch_attn_softmax<float>(float* vals,
const float* attn_mask,
int batch_size,
int heads,
int sequence_length,
hipStream_t stream)
{
const int threads = 128;
int seq_length4 = sequence_length / 4;
dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
int iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 8)
hipLaunchKernelGGL(( attn_softmax<2, (threads / 2), 2>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 16)
hipLaunchKernelGGL(( attn_softmax<4, (threads / 4), 4>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 32)
hipLaunchKernelGGL(( attn_softmax<8, (threads / 8), 8>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 64)
hipLaunchKernelGGL(( attn_softmax<16, (threads / 16), 16>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 128)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 32), 32>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 256)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 64), 64>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else {
const int threads = 256;
dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 512)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 128), 128>), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4))
hipLaunchKernelGGL(( attn_softmax<32, 1, 128>), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, attn_mask, heads, seq_length4, iterations);
else
throw std::runtime_error(
"Unsupport Seq_Length! Check the restriction of the max_threads and "
"max_thread_iterations!");
}
}
template <>
void launch_attn_softmax<__half>(__half* vals,
const __half* attn_mask,
int batch_size,
int heads,
int sequence_length,
hipStream_t stream)
{
const int threads = 128;
int seq_length4 = sequence_length / 4;
dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
int iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 8)
hipLaunchKernelGGL(( attn_softmax<2, (threads / 2), 2>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 16)
hipLaunchKernelGGL(( attn_softmax<4, (threads / 4), 4>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 32)
hipLaunchKernelGGL(( attn_softmax<8, (threads / 8), 8>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 64)
hipLaunchKernelGGL(( attn_softmax<16, (threads / 16), 16>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 128)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 32), 32>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 256)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 64), 64>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else {
const int threads = 256;
dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 512)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 128), 128>), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4))
hipLaunchKernelGGL(( attn_softmax<32, 1, 128>), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, attn_mask, heads, seq_length4, iterations);
else
throw std::runtime_error(
"Unsupport Seq_Length! Check the restriction of the max_threads and "
"max_thread_iterations!");
}
}
template <typename T, int tbSize, int blockStride>
__global__ void softmax_backward_kernel(T* out_grad, const T* soft_inp, int seq_length)
{
__shared__ float partialSum[MAX_WARP_NUM];
int warp_num = blockDim.x >> WARP_SIZE_BITS; // warp-count = num_threads / WARP_SIZE (32)
int iteration_stride = blockDim.x;
int block_width = blockStride * seq_length;
int iterations = (seq_length < (MAX_THREAD_ITERATIONS * iteration_stride)
? (seq_length + iteration_stride - 1) / iteration_stride
: MAX_THREAD_ITERATIONS);
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id >> WARP_SIZE_BITS;
int lane = id & 0x1f;
T val_reg[MAX_THREAD_ITERATIONS];
T soft_reg[MAX_THREAD_ITERATIONS];
float grad_reg = 0.0f;
#pragma unroll
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + id;
if (data_id < block_width) {
val_reg[i] = out_grad[row * block_width + data_id];
soft_reg[i] = soft_inp[row * block_width + data_id];
grad_reg += ((float)val_reg[i] *
(float)soft_reg[i]); // if done in half, the multiplication, we may lose
// 2% of accuracy in computation!!
}
}
for (int i = 1; i < tbSize; i *= 2) grad_reg += g.shfl_xor(grad_reg, i);
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = grad_reg;
b.sync();
if (lane < warp_num) grad_reg = partialSum[lane];
int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
for (int i = 1; i < iters; i *= 2) grad_reg += g.shfl_xor(grad_reg, i);
grad_reg = g.shfl(grad_reg, id / tbSize);
}
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + id;
if (data_id < block_width) {
float temp = (float)soft_reg[i] * ((float)val_reg[i] - grad_reg);
out_grad[row * block_width + data_id] = (T)temp;
}
}
}
template <typename T, int ITERATIONS>
__global__ void softmax_backward_kernel_v2(T* grad /* input & output*/,
const T* output,
int softmax_length)
{
int batch_idx = blockIdx.x * blockDim.y + threadIdx.y;
int offset = batch_idx * softmax_length + threadIdx.x;
grad += offset;
output += offset;
T grad_reg[ITERATIONS];
T output_reg[ITERATIONS];
float sum = 0.0;
#pragma unroll
for (int i = 0; i < ITERATIONS; ++i) {
int curr_idx = threadIdx.x + i * WARP_SIZE;
if (curr_idx < softmax_length) {
grad_reg[i] = grad[i * WARP_SIZE];
output_reg[i] = output[i * WARP_SIZE];
sum += (float)grad_reg[i] * (float)output_reg[i];
}
}
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);
#pragma unroll
for (int i = 0; i < ITERATIONS; ++i) {
int curr_idx = threadIdx.x + i * WARP_SIZE;
if (curr_idx < softmax_length)
grad[i * WARP_SIZE] = (float)output_reg[i] * ((float)grad_reg[i] - sum);
}
}
template <typename T>
void launch_attn_softmax_backward_v2(T* out_grad,
const T* soft_inp,
int batch_size,
int heads,
int seq_length,
hipStream_t stream)
{
const int warps_per_block = 4;
dim3 grid_dim(batch_size * heads * seq_length / warps_per_block);
dim3 block_dim(WARP_SIZE, warps_per_block);
if (seq_length <= 32)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 1>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 64)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 2>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 128)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 4>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 256)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 8>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 384)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 12>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 512)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 16>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 768)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 24>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 1024)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 32>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 2048)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 64>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else
throw std::runtime_error(
std::string("Special sequence length found in softmax backward, seq_length: ") +
std::to_string(seq_length));
}
template void launch_attn_softmax_backward_v2<__half>(__half* out_grad,
const __half* soft_inp,
int batch_size,
int heads,
int seq_length,
hipStream_t stream);
template void launch_attn_softmax_backward_v2<float>(float* out_grad,
const float* soft_inp,
int batch_size,
int heads,
int seq_length,
hipStream_t stream);
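For reference, a minimal host-side sketch of how launch_attn_softmax<float> from the file above could be driven. It is illustrative only: the helper name run_softmax_example, the buffer sizes, and the shapes are assumptions, not part of the deleted file. The scores tensor is laid out as [batch, heads, seq, seq] and softmaxed in place; the additive mask is [batch, seq], broadcast over heads and query rows.

// Illustrative sketch only -- not part of the deleted softmax_kernels.hip.
// Assumes custom_hip_layers.h declares launch_attn_softmax as in the file above.
#include <hip/hip_runtime.h>
#include "custom_hip_layers.h"

void run_softmax_example(hipStream_t stream)
{
    const int batch_size = 8, heads = 12, seq_len = 128;  // assumed sizes
    float* scores = nullptr;  // [batch, heads, seq_len, seq_len] attention logits, softmaxed in place
    float* mask = nullptr;    // [batch, seq_len] additive mask (e.g. 0 or -10000 per key position)
    hipMalloc(&scores, sizeof(float) * batch_size * heads * seq_len * seq_len);
    hipMalloc(&mask, sizeof(float) * batch_size * seq_len);
    // ... fill scores and mask on the device ...
    launch_attn_softmax<float>(scores, mask, batch_size, heads, seq_len, stream);
}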
csrc/transformer_bak/transform_kernels.cu    deleted    100644 → 0
#include "custom_cuda_layers.h"
#define rows_trans 16
#define cols_trans 16
template
<
typename
T
>
__global__
void
Transpose_Kernel
(
const
T
*
inp
,
T
*
out
,
int
row_width
,
int
col_width
)
{
__shared__
T
data_block
[
rows_trans
*
(
cols_trans
+
1
)];
int
r
=
threadIdx
.
x
/
cols_trans
;
int
c
=
threadIdx
.
x
%
cols_trans
;
int
m
=
row_width
/
cols_trans
;
int
i
=
blockIdx
.
x
/
m
*
rows_trans
+
r
;
int
j
=
blockIdx
.
x
%
m
*
cols_trans
+
c
;
int
row_stride
=
rows_trans
/
((
rows_trans
*
cols_trans
+
THREADS
-
1
)
/
THREADS
);
for
(
int
k
=
0
;
k
<
rows_trans
;
k
+=
row_stride
)
data_block
[(
k
+
r
)
*
cols_trans
+
c
]
=
inp
[(
i
+
k
)
*
row_width
+
j
];
__syncthreads
();
i
=
blockIdx
.
x
%
m
*
rows_trans
+
r
;
j
=
blockIdx
.
x
/
m
*
cols_trans
+
c
;
for
(
int
k
=
0
;
k
<
rows_trans
;
k
+=
row_stride
)
out
[(
i
+
k
)
*
col_width
+
j
]
=
data_block
[
c
*
cols_trans
+
r
+
k
];
}
template
<
>
void
Transpose
<
__half
>
(
const
__half
*
inp_mat
,
__half
*
out_mat
,
int
rows
,
int
cols
,
cudaStream_t
stream
)
{
int
threads
=
THREADS
;
Transpose_Kernel
<
__half
><<<
(
rows
*
cols
+
threads
-
1
)
/
threads
,
threads
,
0
,
stream
>>>
(
inp_mat
,
out_mat
,
cols
,
rows
);
}
template
<
>
void
Transpose
<
float
>
(
const
float
*
inp_mat
,
float
*
out_mat
,
int
rows
,
int
cols
,
cudaStream_t
stream
)
{
int
threads
=
THREADS
;
Transpose_Kernel
<
float
><<<
(
rows
*
cols
+
threads
-
1
)
/
threads
,
threads
,
0
,
stream
>>>
(
inp_mat
,
out_mat
,
cols
,
rows
);
}
template
<
typename
T
>
__global__
void
transform_0213
(
T
*
output
,
const
T
*
vals
,
int
hidden_dim
,
int
seq_length
,
int
heads
,
int
head_ext
);
template
<
>
__global__
void
transform_0213
<
float
>
(
float
*
output
,
const
float
*
vals
,
int
hidden_dim
,
int
seq_length
,
int
heads
,
int
head_ext
)
{
int
d0_stride
=
hidden_dim
*
seq_length
;
int
d1_stride
=
hidden_dim
;
int
d2_stride
=
hidden_dim
/
heads
;
int
d0_out_stride
=
d0_stride
;
int
d1_out_stride
=
d2_stride
;
int
d2_out_stride
=
d2_stride
*
seq_length
;
int
d0
=
blockIdx
.
x
;
// Batch
int
d1
=
blockIdx
.
y
/
head_ext
;
// Sequence ID (0-127)
int
d2
=
threadIdx
.
y
+
(
blockIdx
.
y
%
head_ext
)
*
(
heads
/
head_ext
);
// Head (0-11)
int
d3
=
threadIdx
.
x
;
// Values (groups of 4)
const
float4
*
vals_vec
=
reinterpret_cast
<
const
float4
*>
(
vals
);
float4
*
output_vec
=
reinterpret_cast
<
float4
*>
(
output
);
float4
inputs
=
vals_vec
[
d0
*
d0_stride
+
d1
*
d1_stride
+
d2
*
d2_stride
+
d3
];
output_vec
[
d0
*
d0_out_stride
+
d1
*
d1_out_stride
+
d2
*
d2_out_stride
+
d3
]
=
inputs
;
}
template
<
>
__global__
void
transform_0213
<
__half
>
(
__half
*
output
,
const
__half
*
vals
,
int
hidden_dim
,
int
seq_length
,
int
heads
,
int
head_ext
)
{
#ifdef HALF_PRECISION_AVAILABLE
int
d0_stride
=
hidden_dim
*
seq_length
;
int
d1_stride
=
hidden_dim
;
int
d2_stride
=
hidden_dim
/
heads
;
int
d0_out_stride
=
d0_stride
;
int
d1_out_stride
=
d2_stride
;
int
d2_out_stride
=
d2_stride
*
seq_length
;
int
d0
=
blockIdx
.
x
;
// Batch
int
d1
=
blockIdx
.
y
/
head_ext
;
// Sequence ID (0-127)
int
d2
=
threadIdx
.
y
+
(
blockIdx
.
y
%
head_ext
)
*
(
heads
/
head_ext
);
// Head (0-11)
int
d3
=
threadIdx
.
x
;
// Values (groups of 4)
float4
vals_arr
[
1
];
const
float4
*
vals_vec
=
reinterpret_cast
<
const
float4
*>
(
vals
);
float4
*
output_vec
=
reinterpret_cast
<
float4
*>
(
output
);
vals_arr
[
0
]
=
vals_vec
[
d0
*
d0_stride
+
d1
*
d1_stride
+
d2
*
d2_stride
+
d3
];
output_vec
[
d0
*
d0_out_stride
+
d1
*
d1_out_stride
+
d2
*
d2_out_stride
+
d3
]
=
vals_arr
[
0
];
#endif
}
template
<
>
void
launch_transform_0213
<
float
>
(
float
*
output
,
const
float
*
vals
,
int
batch_size
,
int
seq_length
,
int
hidden_dim
,
int
heads
,
cudaStream_t
stream
)
{
hidden_dim
>>=
2
;
int
head_ext
=
(
hidden_dim
-
1
)
/
MAX_THREADS
+
1
;
dim3
block_dim
(
hidden_dim
/
heads
,
(
heads
/
head_ext
));
dim3
grid_dim
(
batch_size
,
(
seq_length
*
head_ext
));
transform_0213
<
float
>
<<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
output
,
vals
,
hidden_dim
,
seq_length
,
heads
,
head_ext
);
}
template
<
>
void
launch_transform_0213
<
__half
>
(
__half
*
output
,
const
__half
*
vals
,
int
batch_size
,
int
seq_length
,
int
hidden_dim
,
int
heads
,
cudaStream_t
stream
)
{
hidden_dim
>>=
3
;
int
head_ext
=
(
hidden_dim
-
1
)
/
MAX_THREADS
+
1
;
dim3
block_dim
(
hidden_dim
/
heads
,
(
heads
/
head_ext
));
dim3
grid_dim
(
batch_size
,
(
seq_length
*
head_ext
));
transform_0213
<
__half
>
<<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
output
,
vals
,
hidden_dim
,
seq_length
,
heads
,
head_ext
);
}
// Bias add
template
<
typename
T
>
__global__
void
bias_add_transform_0213
(
T
*
output
,
const
T
*
vals
,
const
T
*
bias
,
int
hidden_dim
,
int
seq_length
,
int
heads
,
int
head_ext
);
template
<
>
__global__
void
bias_add_transform_0213
<
float
>
(
float
*
output
,
const
float
*
vals
,
const
float
*
bias
,
int
hidden_dim
,
int
seq_length
,
int
heads
,
int
head_ext
)
{
int
d0_stride
=
hidden_dim
*
seq_length
;
int
d1_stride
=
hidden_dim
;
int
d2_stride
=
hidden_dim
/
heads
;
int
d0_out_stride
=
d0_stride
;
int
d1_out_stride
=
d2_stride
;
int
d2_out_stride
=
d2_stride
*
seq_length
;
int
d0
=
blockIdx
.
x
;
// Batch
int
d1
=
blockIdx
.
y
;
// Sequence ID (0-127)
int
cnt
=
blockIdx
.
z
/
head_ext
;
// Hidden count
int
d2
=
threadIdx
.
y
+
(
blockIdx
.
z
%
head_ext
)
*
(
heads
/
head_ext
);
// Head (0-11)
int
d3
=
threadIdx
.
x
;
// Values (groups of 4)
const
float4
*
vals_vec
=
reinterpret_cast
<
const
float4
*>
(
vals
);
const
float4
*
bias_vec
=
reinterpret_cast
<
const
float4
*>
(
bias
);
float4
*
output_vec
=
reinterpret_cast
<
float4
*>
(
output
);
float4
inputs
=
vals_vec
[
d0
*
d0_stride
*
(
gridDim
.
z
/
head_ext
)
+
cnt
*
d1_stride
+
d1
*
d1_stride
*
(
gridDim
.
z
/
head_ext
)
+
d2
*
d2_stride
+
d3
];
float4
biases
=
bias_vec
[
cnt
*
d1_stride
+
d2
*
d2_stride
+
d3
];
float4
outputs
;
outputs
.
x
=
inputs
.
x
+
biases
.
x
;
outputs
.
y
=
inputs
.
y
+
biases
.
y
;
outputs
.
z
=
inputs
.
z
+
biases
.
z
;
outputs
.
w
=
inputs
.
w
+
biases
.
w
;
output_vec
[
cnt
*
d0_out_stride
*
gridDim
.
x
+
d0
*
d0_out_stride
+
d1
*
d1_out_stride
+
d2
*
d2_out_stride
+
d3
]
=
outputs
;
}
#define ATTN_H 3
#define MAX_SEQ_LINE 10
template
<
>
__global__
void
bias_add_transform_0213
<
__half
>
(
__half
*
output
,
const
__half
*
vals
,
const
__half
*
bias
,
int
hidden_dim
,
int
seq_length
,
int
heads
,
int
head_ext
)
{
#ifdef HALF_PRECISION_AVAILABLE
int
d0_stride
=
hidden_dim
*
seq_length
;
int
d1_stride
=
hidden_dim
;
int
d2_stride
=
hidden_dim
/
heads
;
int
d2_out_stride
=
d2_stride
*
seq_length
;
int
d0
=
blockIdx
.
x
;
// Batch
int
d1
=
blockIdx
.
y
;
// Sequence ID (0-127)
int
cnt
=
blockIdx
.
z
/
head_ext
;
// Hidden count
int
d2
=
threadIdx
.
y
+
(
blockIdx
.
z
%
head_ext
)
*
(
heads
/
head_ext
);
// Head (0-11)
int
d3
=
threadIdx
.
x
;
// Values (groups of 4)
float4
vals_arr
;
float4
bias_arr
;
float4
output_arr
;
__half2
*
vals_half
=
reinterpret_cast
<
__half2
*>
(
&
vals_arr
);
__half2
*
bias_half
=
reinterpret_cast
<
__half2
*>
(
&
bias_arr
);
__half2
*
output_half
=
reinterpret_cast
<
__half2
*>
(
&
output_arr
);
const
float4
*
vals_vec
=
reinterpret_cast
<
const
float4
*>
(
vals
);
const
float4
*
bias_vec
=
reinterpret_cast
<
const
float4
*>
(
bias
);
float4
*
output_vec
=
reinterpret_cast
<
float4
*>
(
output
);
vals_vec
+=
(
d0
*
d0_stride
*
(
gridDim
.
z
/
head_ext
));
vals_vec
+=
(
d1
*
d1_stride
*
(
gridDim
.
z
/
head_ext
));
vals_vec
+=
(
cnt
*
d1_stride
);
vals_vec
+=
(
d2
*
d2_stride
);
bias_vec
+=
(
cnt
*
d1_stride
);
bias_vec
+=
(
d2
*
d2_stride
);
output_vec
+=
(
cnt
*
d0_stride
*
gridDim
.
x
);
output_vec
+=
(
d1
*
d2_stride
);
output_vec
+=
(
d0
*
d0_stride
);
output_vec
+=
(
d2
*
d2_out_stride
);
bias_arr
=
bias_vec
[
d3
];
vals_arr
=
vals_vec
[
d3
];
#if defined(__ACC_HALF__)
output_half
[
0
]
=
vals_half
[
0
]
+
bias_half
[
0
];
output_half
[
1
]
=
vals_half
[
1
]
+
bias_half
[
1
];
output_half
[
2
]
=
vals_half
[
2
]
+
bias_half
[
2
];
output_half
[
3
]
=
vals_half
[
3
]
+
bias_half
[
3
];
#else
float2
bias_arr_f
[
4
];
float2
vals_arr_f
[
4
];
#pragma unroll
for
(
int
l
=
0
;
l
<
4
;
l
++
)
{
bias_arr_f
[
l
]
=
__half22float2
(
bias_half
[
l
]);
vals_arr_f
[
l
]
=
__half22float2
(
vals_half
[
l
]);
vals_arr_f
[
l
].
x
+=
bias_arr_f
[
l
].
x
;
vals_arr_f
[
l
].
y
+=
bias_arr_f
[
l
].
y
;
output_half
[
l
]
=
__float22half2_rn
(
vals_arr_f
[
l
]);
}
#endif
output_vec
[
d3
]
=
output_arr
;
#endif
}
__global__
void
bias_add_transform_0213_v2
(
__half
*
output
,
const
__half
*
vals
,
const
__half
*
bias
,
int
hidden_dim
,
int
seq_length
,
int
heads
)
{
#ifdef HALF_PRECISION_AVAILABLE
__shared__
float4
in_data
[
3072
];
int
d0_stride
=
hidden_dim
*
seq_length
;
int
d1_stride
=
hidden_dim
;
int
d2_stride
=
hidden_dim
/
heads
;
int
iteration_stride
=
d1_stride
*
blockDim
.
z
;
// Hidden * 3 / 8
int
batch_stride
=
d0_stride
*
blockDim
.
z
;
// Hidden * S * 3 / 8
int
d0_out_stride
=
d0_stride
;
int
d1_out_stride
=
d2_stride
;
int
d2_out_stride
=
d2_stride
*
seq_length
;
int
d0
=
blockIdx
.
x
;
// Batch
int
d1
=
blockIdx
.
y
;
// Sequence ID (0-127)
int
cnt
=
threadIdx
.
z
;
// blockIdx.z; // Hidden count
int
d2
=
threadIdx
.
y
;
// Head (0-11)
int
d3
=
threadIdx
.
x
;
// Values (groups of 4)
float4
vals_arr
[
1
];
float4
bias_arr
[
1
];
float4
output_arr
[
1
];
__half2
*
vals_half
=
reinterpret_cast
<
__half2
*>
(
vals_arr
);
__half2
*
bias_half
=
reinterpret_cast
<
__half2
*>
(
bias_arr
);
__half2
*
output_half
=
reinterpret_cast
<
__half2
*>
(
output_arr
);
const
float4
*
vals_vec
=
reinterpret_cast
<
const
float4
*>
(
vals
);
const
float4
*
bias_vec
=
reinterpret_cast
<
const
float4
*>
(
bias
);
float4
*
output_vec
=
reinterpret_cast
<
float4
*>
(
output
);
int
iter_index
=
cnt
*
d1_stride
+
d2
*
d2_stride
+
d3
;
int
input_offset
=
d0
*
batch_stride
+
d1
*
(
iteration_stride
<<
1
);
bias_arr
[
0
]
=
bias_vec
[
iter_index
];
#pragma unroll
for
(
int
iter
=
0
;
iter
<
2
;
iter
++
)
{
int
iter_id
=
iter
*
iteration_stride
+
iter_index
;
vals_arr
[
0
]
=
vals_vec
[
input_offset
+
iter_id
];
output_half
[
0
]
=
vals_half
[
0
]
+
bias_half
[
0
];
output_half
[
1
]
=
vals_half
[
1
]
+
bias_half
[
1
];
output_half
[
2
]
=
vals_half
[
2
]
+
bias_half
[
2
];
output_half
[
3
]
=
vals_half
[
3
]
+
bias_half
[
3
];
in_data
[
iter_id
]
=
output_arr
[
0
];
}
__syncthreads
();
iteration_stride
=
blockDim
.
z
*
(
blockDim
.
y
>>
1
);
int
matrix_stride
=
(
d0_out_stride
*
gridDim
.
x
);
int
head_count
=
(
d2
>>
1
)
+
cnt
*
(
blockDim
.
y
>>
1
);
int
out_index
=
d0
*
d0_out_stride
+
d1
*
(
d1_out_stride
<<
1
)
+
d3
+
(
d2
%
2
)
*
d2_stride
;
#pragma unroll
for
(
int
iter
=
0
;
iter
<
2
;
iter
++
)
{
int
iter_row
=
(
iter
*
iteration_stride
)
+
head_count
;
int
iter_offset
=
(
iter_row
%
blockDim
.
y
)
*
d2_out_stride
+
(
iter_row
/
blockDim
.
y
)
*
matrix_stride
;
output_vec
[
out_index
+
iter_offset
]
=
in_data
[
iter_row
*
d2_stride
+
d3
+
(
d2
%
2
)
*
(
d1_stride
*
blockDim
.
z
)];
}
#endif
}
// [B S C*H] - > C * [B A S N]
template
<
>
void
launch_bias_add_transform_0213
<
float
>
(
float
*
output
,
const
float
*
vals
,
const
float
*
bias
,
int
batch_size
,
int
seq_length
,
int
hidden_dim
,
int
heads
,
cudaStream_t
stream
,
int
trans_count
)
{
hidden_dim
>>=
2
;
int
head_ext
=
(
hidden_dim
-
1
)
/
MAX_THREADS
+
1
;
dim3
block_dim
(
hidden_dim
/
heads
,
(
heads
/
head_ext
));
dim3
grid_dim
(
batch_size
,
seq_length
,
(
trans_count
*
head_ext
));
bias_add_transform_0213
<
float
><<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
output
,
vals
,
bias
,
hidden_dim
,
seq_length
,
heads
,
head_ext
);
}
template
<
>
void
launch_bias_add_transform_0213
<
__half
>
(
__half
*
output
,
const
__half
*
vals
,
const
__half
*
bias
,
int
batch_size
,
int
seq_length
,
int
hidden_dim
,
int
heads
,
cudaStream_t
stream
,
int
trans_count
)
{
hidden_dim
>>=
3
;
if
(
hidden_dim
>
128
||
hidden_dim
<
16
)
{
int
head_ext
=
(
hidden_dim
-
1
)
/
MAX_THREADS
+
1
;
dim3
block_dim
(
hidden_dim
/
heads
,
(
heads
/
head_ext
));
dim3
grid_dim
(
batch_size
,
seq_length
,
(
trans_count
*
head_ext
));
bias_add_transform_0213
<
__half
><<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
output
,
vals
,
bias
,
hidden_dim
,
seq_length
,
heads
,
head_ext
);
}
else
{
dim3
block_dim
(
hidden_dim
/
heads
,
heads
,
trans_count
);
dim3
grid_dim
(
batch_size
,
seq_length
/
2
);
bias_add_transform_0213_v2
<<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
output
,
vals
,
bias
,
hidden_dim
,
seq_length
,
heads
);
}
}
template
<
typename
T
>
__global__
void
transform4d_0213
(
T
*
out
,
const
T
*
in
,
int
heads
,
int
seq_length
,
int
hidden_dim
,
int
head_ext
);
template
<
>
__global__
void
transform4d_0213
<
float
>
(
float
*
out
,
const
float
*
in
,
int
heads
,
int
seq_length
,
int
hidden_dim
,
int
head_ext
)
{
int
d0_stride
=
hidden_dim
*
seq_length
;
int
d1_stride
=
d0_stride
/
heads
;
int
d2_stride
=
hidden_dim
/
heads
;
int
d0_out_stride
=
d0_stride
;
int
d1_out_stride
=
d2_stride
;
int
d2_out_stride
=
hidden_dim
;
int
d0
=
blockIdx
.
x
;
// Batch
int
d1
=
blockIdx
.
y
/
((
seq_length
-
1
)
/
blockDim
.
y
+
1
);
// Head
int
d2
=
(
threadIdx
.
y
+
blockDim
.
y
*
blockIdx
.
y
)
%
seq_length
;
int
cnt
=
blockIdx
.
z
;
int
d3
=
threadIdx
.
x
;
// Values (groups of 8)
if
(
d2
<
seq_length
)
{
const
float4
*
in_vec
=
reinterpret_cast
<
const
float4
*>
(
in
);
float4
*
out_vec
=
reinterpret_cast
<
float4
*>
(
out
);
float4
vals_vec
=
in_vec
[
cnt
*
d0_stride
*
gridDim
.
x
+
d0
*
d0_stride
+
d1
*
d1_stride
+
d2
*
d2_stride
+
d3
];
out_vec
[
d0
*
d0_out_stride
*
gridDim
.
z
+
cnt
*
d2_out_stride
+
d1
*
d1_out_stride
+
d2
*
d2_out_stride
*
gridDim
.
z
+
d3
]
=
vals_vec
;
}
}
template
<
>
__global__
void
transform4d_0213
<
__half
>
(
__half
*
out
,
const
__half
*
in
,
int
heads
,
int
seq_length
,
int
hidden_dim
,
int
head_ext
)
{
#ifdef HALF_PRECISION_AVAILABLE
int
d0_stride
=
hidden_dim
*
(
seq_length
/
head_ext
);
int
d1_stride
=
hidden_dim
;
int
d2_stride
=
hidden_dim
/
heads
;
int
d0
=
blockIdx
.
x
;
// Batch
int
d1
=
threadIdx
.
y
+
(
blockIdx
.
z
%
head_ext
)
*
(
heads
/
head_ext
);
// Head
int
d2
=
blockIdx
.
z
/
head_ext
;
// Sequence
int
cnt
=
blockIdx
.
y
;
// Hidden count
int
d3
=
threadIdx
.
x
;
// Values (groups of 8)
const
float4
*
in_vec
=
reinterpret_cast
<
const
float4
*>
(
in
);
float4
*
out_vec
=
reinterpret_cast
<
float4
*>
(
out
);
in_vec
+=
(
cnt
*
d0_stride
*
gridDim
.
x
);
in_vec
+=
(
d0
*
d0_stride
);
in_vec
+=
(
d2
*
d2_stride
);
in_vec
+=
(
d1
*
d2_stride
*
seq_length
);
out_vec
+=
(
cnt
*
d1_stride
);
out_vec
+=
(
d1
*
d2_stride
);
out_vec
+=
(
d0
*
d0_stride
*
gridDim
.
y
);
out_vec
+=
(
d2
*
d1_stride
*
gridDim
.
y
);
out_vec
[
d3
]
=
in_vec
[
d3
];
#endif
}
__global__
void
transform4d_0213_v2
(
__half
*
out
,
const
__half
*
in
,
int
heads
,
int
seq_length
,
int
hidden_dim
)
{
#ifdef HALF_PRECISION_AVAILABLE
__shared__
float4
in_data
[
3072
];
int
d0_stride
=
hidden_dim
*
seq_length
;
int
d1_stride
=
hidden_dim
;
int
d2_stride
=
hidden_dim
/
heads
;
int
d0
=
blockIdx
.
x
;
// Batch
int
d1
=
threadIdx
.
y
;
// Head
int
d2
=
blockIdx
.
y
;
// Sequence
int
cnt
=
threadIdx
.
z
;
// Hidden count
int
d3
=
threadIdx
.
x
;
// Values (groups of 8)
const
float4
*
in_vec
=
reinterpret_cast
<
const
float4
*>
(
in
);
float4
*
out_vec
=
reinterpret_cast
<
float4
*>
(
out
);
int
input_offset
=
d0
*
d0_stride
+
d2
*
(
d2_stride
<<
1
)
+
d3
+
(
d1
%
2
)
*
d2_stride
;
int
head_count
=
(
d1
>>
1
)
+
cnt
*
(
blockDim
.
y
>>
1
);
int
iteration_stride
=
blockDim
.
z
*
(
blockDim
.
y
>>
1
);
int
matrix_stride
=
(
d0_stride
*
gridDim
.
x
);
#pragma unroll
for
(
int
iter
=
0
;
iter
<
2
;
iter
++
)
{
int
iter_row
=
iter
*
iteration_stride
+
head_count
;
int
iter_offset
=
(
iter_row
%
blockDim
.
y
)
*
d2_stride
;
in_data
[
d3
+
iter_offset
+
(
iter_row
/
blockDim
.
y
+
(
d1
%
2
)
*
blockDim
.
z
)
*
d1_stride
]
=
in_vec
[
input_offset
+
iter_offset
*
seq_length
+
(
iter_row
/
blockDim
.
y
)
*
matrix_stride
];
}
__syncthreads
();
iteration_stride
=
d1_stride
*
blockDim
.
z
;
int
iter_index
=
cnt
*
d1_stride
+
d1
*
d2_stride
+
d3
;
int
output_offset
=
d0
*
d0_stride
*
blockDim
.
z
+
d2
*
(
iteration_stride
<<
1
);
#pragma unroll
for
(
int
iter
=
0
;
iter
<
2
;
iter
++
)
{
int
iter_id
=
iter
*
iteration_stride
+
iter_index
;
out_vec
[
output_offset
+
iter_id
]
=
in_data
[
iter_id
];
}
#endif
}
// 3 * [B A S N] - > [B S C*H]
template
<
>
void
launch_transform4d_0213
<
float
>
(
float
*
out
,
const
float
*
in
,
int
batch_size
,
int
heads
,
int
seq_length
,
int
hidden_dim
,
cudaStream_t
stream
,
int
trans_count
)
{
hidden_dim
>>=
2
;
dim3
grid_dims
(
batch_size
,
heads
*
((
seq_length
-
1
)
/
8
+
1
),
trans_count
);
dim3
block_dims
(
hidden_dim
/
heads
,
8
);
transform4d_0213
<
float
>
<<<
grid_dims
,
block_dims
,
0
,
stream
>>>
(
out
,
in
,
heads
,
seq_length
,
hidden_dim
,
1
);
}
template
<
>
void
launch_transform4d_0213
<
__half
>
(
__half
*
out
,
const
__half
*
in
,
int
batch_size
,
int
heads
,
int
seq_length
,
int
hidden_dim
,
cudaStream_t
stream
,
int
trans_count
)
{
hidden_dim
>>=
3
;
if
(
hidden_dim
>
128
||
hidden_dim
<
16
)
{
int
head_ext
=
(
hidden_dim
-
1
)
/
MAX_THREADS
+
1
;
dim3
grid_dims
(
batch_size
,
trans_count
,
(
seq_length
*
head_ext
));
dim3
block_dims
(
hidden_dim
/
heads
,
(
heads
/
head_ext
));
transform4d_0213
<
__half
><<<
grid_dims
,
block_dims
,
0
,
stream
>>>
(
out
,
in
,
heads
,
seq_length
,
hidden_dim
,
head_ext
);
}
else
{
dim3
grid_dims
(
batch_size
,
seq_length
/
2
);
dim3
block_dims
(
hidden_dim
/
heads
,
heads
,
trans_count
);
transform4d_0213_v2
<<<
grid_dims
,
block_dims
,
0
,
stream
>>>
(
out
,
in
,
heads
,
seq_length
,
hidden_dim
);
}
}
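For orientation, the permutation transform_0213 performs can be written as a plain host-side reference; this is a sketch only (transform_0213_reference is a hypothetical helper, not part of the deleted file). It reads the input as [batch, seq, heads, head_dim] and writes [batch, heads, seq, head_dim], i.e. axes reordered as (0, 2, 1, 3), which is where the kernel names come from; the kernels above compute the same mapping with float4 vectorization, so d3 indexes groups of 4 floats (or 8 halves).

// Hypothetical host-side reference for the (0, 2, 1, 3) permutation -- illustration only.
// in  : [B, S, H, N] contiguous, out : [B, H, S, N] contiguous (N = hidden_dim / heads).
void transform_0213_reference(float* out, const float* in, int B, int S, int H, int N)
{
    for (int b = 0; b < B; ++b)
        for (int s = 0; s < S; ++s)
            for (int h = 0; h < H; ++h)
                for (int n = 0; n < N; ++n)
                    out[((b * H + h) * S + s) * N + n] = in[((b * S + s) * H + h) * N + n];
}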
csrc/transformer_bak/transform_kernels.hip    deleted    100644 → 0
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
#define rows_trans 16
#define cols_trans 16
template <typename T>
__global__ void Transpose_Kernel(const T* inp, T* out, int row_width, int col_width)
{
__shared__ T data_block[rows_trans * (cols_trans + 1)];
int r = threadIdx.x / cols_trans;
int c = threadIdx.x % cols_trans;
int m = row_width / cols_trans;
int i = blockIdx.x / m * rows_trans + r;
int j = blockIdx.x % m * cols_trans + c;
int row_stride = rows_trans / ((rows_trans * cols_trans + THREADS - 1) / THREADS);
for (int k = 0; k < rows_trans; k += row_stride)
data_block[(k + r) * cols_trans + c] = inp[(i + k) * row_width + j];
__syncthreads();
i = blockIdx.x % m * rows_trans + r;
j = blockIdx.x / m * cols_trans + c;
for (int k = 0; k < rows_trans; k += row_stride)
out[(i + k) * col_width + j] = data_block[c * cols_trans + r + k];
}
template <>
void Transpose<__half>(const __half* inp_mat,
__half* out_mat,
int rows,
int cols,
hipStream_t stream)
{
int threads = THREADS;
hipLaunchKernelGGL(( Transpose_Kernel<__half>), dim3((rows * cols + threads - 1) / threads), dim3(threads), 0, stream,
inp_mat, out_mat, cols, rows);
}
template <>
void Transpose<float>(const float* inp_mat, float* out_mat, int rows, int cols, hipStream_t stream)
{
int threads = THREADS;
hipLaunchKernelGGL(( Transpose_Kernel<float>), dim3((rows * cols + threads - 1) / threads), dim3(threads), 0, stream,
inp_mat, out_mat, cols, rows);
}
template <typename T>
__global__ void transform_0213(T* output,
const T* vals,
int hidden_dim,
int seq_length,
int heads,
int head_ext);
template <>
__global__ void transform_0213<float>(float* output,
const float* vals,
int hidden_dim,
int seq_length,
int heads,
int head_ext)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y / head_ext; // Sequence ID (0-127)
int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
float4* output_vec = reinterpret_cast<float4*>(output);
float4 inputs = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3];
output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = inputs;
}
template <>
__global__ void transform_0213<__half>(__half* output,
const __half* vals,
int hidden_dim,
int seq_length,
int heads,
int head_ext)
{
#ifdef HALF_PRECISION_AVAILABLE
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y / head_ext; // Sequence ID (0-127)
int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
float4 vals_arr[1];
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
float4* output_vec = reinterpret_cast<float4*>(output);
vals_arr[0] = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3];
output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = vals_arr[0];
#endif
}
template <>
void launch_transform_0213<float>(float* output,
const float* vals,
int batch_size,
int seq_length,
int hidden_dim,
int heads,
hipStream_t stream)
{
hidden_dim >>= 2;
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, (seq_length * head_ext));
hipLaunchKernelGGL(( transform_0213<float>)
, dim3(grid_dim), dim3(block_dim), 0, stream, output, vals, hidden_dim, seq_length, heads, head_ext);
}
template <>
void launch_transform_0213<__half>(__half* output,
const __half* vals,
int batch_size,
int seq_length,
int hidden_dim,
int heads,
hipStream_t stream)
{
hidden_dim >>= 3;
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, (seq_length * head_ext));
hipLaunchKernelGGL(( transform_0213<__half>)
, dim3(grid_dim), dim3(block_dim), 0, stream, output, vals, hidden_dim, seq_length, heads, head_ext);
}
// Bias add
template <typename T>
__global__ void bias_add_transform_0213(T* output,
const T* vals,
const T* bias,
int hidden_dim,
int seq_length,
int heads,
int head_ext);
template <>
__global__ void bias_add_transform_0213<float>(float* output,
const float* vals,
const float* bias,
int hidden_dim,
int seq_length,
int heads,
int head_ext)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = blockIdx.z / head_ext; // Hidden count
int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
const float4* bias_vec = reinterpret_cast<const float4*>(bias);
float4* output_vec = reinterpret_cast<float4*>(output);
float4 inputs = vals_vec[d0 * d0_stride * (gridDim.z / head_ext) + cnt * d1_stride +
d1 * d1_stride * (gridDim.z / head_ext) + d2 * d2_stride + d3];
float4 biases = bias_vec[cnt * d1_stride + d2 * d2_stride + d3];
float4 outputs;
outputs.x = inputs.x + biases.x;
outputs.y = inputs.y + biases.y;
outputs.z = inputs.z + biases.z;
outputs.w = inputs.w + biases.w;
output_vec[cnt * d0_out_stride * gridDim.x + d0 * d0_out_stride + d1 * d1_out_stride +
d2 * d2_out_stride + d3] = outputs;
}
#define ATTN_H 3
#define MAX_SEQ_LINE 10
template <>
__global__ void bias_add_transform_0213<__half>(__half* output,
const __half* vals,
const __half* bias,
int hidden_dim,
int seq_length,
int heads,
int head_ext)
{
#ifdef HALF_PRECISION_AVAILABLE
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = blockIdx.z / head_ext; // Hidden count
int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
float4 vals_arr;
float4 bias_arr;
float4 output_arr;
__half2* vals_half = reinterpret_cast<__half2*>(&vals_arr);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_arr);
__half2* output_half = reinterpret_cast<__half2*>(&output_arr);
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
const float4* bias_vec = reinterpret_cast<const float4*>(bias);
float4* output_vec = reinterpret_cast<float4*>(output);
vals_vec += (d0 * d0_stride * (gridDim.z / head_ext));
vals_vec += (d1 * d1_stride * (gridDim.z / head_ext));
vals_vec += (cnt * d1_stride);
vals_vec += (d2 * d2_stride);
bias_vec += (cnt * d1_stride);
bias_vec += (d2 * d2_stride);
output_vec += (cnt * d0_stride * gridDim.x);
output_vec += (d1 * d2_stride);
output_vec += (d0 * d0_stride);
output_vec += (d2 * d2_out_stride);
bias_arr = bias_vec[d3];
vals_arr = vals_vec[d3];
#if defined(__ACC_HALF__)
output_half[0] = vals_half[0] + bias_half[0];
output_half[1] = vals_half[1] + bias_half[1];
output_half[2] = vals_half[2] + bias_half[2];
output_half[3] = vals_half[3] + bias_half[3];
#else
float2 bias_arr_f[4];
float2 vals_arr_f[4];
#pragma unroll
for (int l = 0; l < 4; l++) {
bias_arr_f[l] = __half22float2(bias_half[l]);
vals_arr_f[l] = __half22float2(vals_half[l]);
vals_arr_f[l].x += bias_arr_f[l].x;
vals_arr_f[l].y += bias_arr_f[l].y;
output_half[l] = __float22half2_rn(vals_arr_f[l]);
}
#endif
output_vec[d3] = output_arr;
#endif
}
__global__ void bias_add_transform_0213_v2(__half* output,
const __half* vals,
const __half* bias,
int hidden_dim,
int seq_length,
int heads)
{
#ifdef HALF_PRECISION_AVAILABLE
__shared__ float4 in_data[3072];
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int iteration_stride = d1_stride * blockDim.z; // Hidden * 3 / 8
int batch_stride = d0_stride * blockDim.z; // Hidden * S * 3 / 8
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = threadIdx.z; // blockIdx.z; // Hidden count
int d2 = threadIdx.y; // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
float4 vals_arr[1];
float4 bias_arr[1];
float4 output_arr[1];
__half2* vals_half = reinterpret_cast<__half2*>(vals_arr);
__half2* bias_half = reinterpret_cast<__half2*>(bias_arr);
__half2* output_half = reinterpret_cast<__half2*>(output_arr);
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
const float4* bias_vec = reinterpret_cast<const float4*>(bias);
float4* output_vec = reinterpret_cast<float4*>(output);
int iter_index = cnt * d1_stride + d2 * d2_stride + d3;
int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1);
bias_arr[0] = bias_vec[iter_index];
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_id = iter * iteration_stride + iter_index;
vals_arr[0] = vals_vec[input_offset + iter_id];
output_half[0] = vals_half[0] + bias_half[0];
output_half[1] = vals_half[1] + bias_half[1];
output_half[2] = vals_half[2] + bias_half[2];
output_half[3] = vals_half[3] + bias_half[3];
in_data[iter_id] = output_arr[0];
}
__syncthreads();
iteration_stride = blockDim.z * (blockDim.y >> 1);
int matrix_stride = (d0_out_stride * gridDim.x);
int head_count = (d2 >> 1) + cnt * (blockDim.y >> 1);
int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + (d2 % 2) * d2_stride;
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_row = (iter * iteration_stride) + head_count;
int iter_offset =
(iter_row % blockDim.y) * d2_out_stride + (iter_row / blockDim.y) * matrix_stride;
output_vec[out_index + iter_offset] =
in_data[iter_row * d2_stride + d3 + (d2 % 2) * (d1_stride * blockDim.z)];
}
#endif
}
// [B S C*H] - > C * [B A S N]
template <>
void launch_bias_add_transform_0213<float>(float* output,
const float* vals,
const float* bias,
int batch_size,
int seq_length,
int hidden_dim,
int heads,
hipStream_t stream,
int trans_count)
{
hidden_dim >>= 2;
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));
hipLaunchKernelGGL(( bias_add_transform_0213<float>), dim3(grid_dim), dim3(block_dim), 0, stream,
output, vals, bias, hidden_dim, seq_length, heads, head_ext);
}
template <>
void launch_bias_add_transform_0213<__half>(__half* output,
const __half* vals,
const __half* bias,
int batch_size,
int seq_length,
int hidden_dim,
int heads,
hipStream_t stream,
int trans_count)
{
hidden_dim >>= 3;
if (hidden_dim > 128 || hidden_dim < 16) {
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));
hipLaunchKernelGGL(( bias_add_transform_0213<__half>), dim3(grid_dim), dim3(block_dim), 0, stream,
output, vals, bias, hidden_dim, seq_length, heads, head_ext);
} else {
dim3 block_dim(hidden_dim / heads, heads, trans_count);
dim3 grid_dim(batch_size, seq_length / 2);
hipLaunchKernelGGL(( bias_add_transform_0213_v2), dim3(grid_dim), dim3(block_dim), 0, stream,
output, vals, bias, hidden_dim, seq_length, heads);
}
}
template <typename T>
__global__ void transform4d_0213(T* out,
const T* in,
int heads,
int seq_length,
int hidden_dim,
int head_ext);
template <>
__global__ void transform4d_0213<float>(float* out,
const float* in,
int heads,
int seq_length,
int hidden_dim,
int head_ext)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = d0_stride / heads;
int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = hidden_dim;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y / ((seq_length - 1) / blockDim.y + 1); // Head
int d2 = (threadIdx.y + blockDim.y * blockIdx.y) % seq_length;
int cnt = blockIdx.z;
int d3 = threadIdx.x; // Values (groups of 8)
if (d2 < seq_length) {
const float4* in_vec = reinterpret_cast<const float4*>(in);
float4* out_vec = reinterpret_cast<float4*>(out);
float4 vals_vec = in_vec[cnt * d0_stride * gridDim.x + d0 * d0_stride + d1 * d1_stride +
d2 * d2_stride + d3];
out_vec[d0 * d0_out_stride * gridDim.z + cnt * d2_out_stride + d1 * d1_out_stride +
d2 * d2_out_stride * gridDim.z + d3] = vals_vec;
}
}
template <>
__global__ void transform4d_0213<__half>(__half* out,
const __half* in,
int heads,
int seq_length,
int hidden_dim,
int head_ext)
{
#ifdef HALF_PRECISION_AVAILABLE
int d0_stride = hidden_dim * (seq_length / head_ext);
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0 = blockIdx.x; // Batch
int d1 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head
int d2 = blockIdx.z / head_ext; // Sequence
int cnt = blockIdx.y; // Hidden count
int d3 = threadIdx.x; // Values (groups of 8)
const float4* in_vec = reinterpret_cast<const float4*>(in);
float4* out_vec = reinterpret_cast<float4*>(out);
in_vec += (cnt * d0_stride * gridDim.x);
in_vec += (d0 * d0_stride);
in_vec += (d2 * d2_stride);
in_vec += (d1 * d2_stride * seq_length);
out_vec += (cnt * d1_stride);
out_vec += (d1 * d2_stride);
out_vec += (d0 * d0_stride * gridDim.y);
out_vec += (d2 * d1_stride * gridDim.y);
out_vec[d3] = in_vec[d3];
#endif
}
__global__ void transform4d_0213_v2(__half* out,
const __half* in,
int heads,
int seq_length,
int hidden_dim)
{
#ifdef HALF_PRECISION_AVAILABLE
__shared__ float4 in_data[3072];
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0 = blockIdx.x; // Batch
int d1 = threadIdx.y; // Head
int d2 = blockIdx.y; // Sequence
int cnt = threadIdx.z; // Hidden count
int d3 = threadIdx.x; // Values (groups of 8)
const float4* in_vec = reinterpret_cast<const float4*>(in);
float4* out_vec = reinterpret_cast<float4*>(out);
int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + (d1 % 2) * d2_stride;
int head_count = (d1 >> 1) + cnt * (blockDim.y >> 1);
int iteration_stride = blockDim.z * (blockDim.y >> 1);
int matrix_stride = (d0_stride * gridDim.x);
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_row = iter * iteration_stride + head_count;
int iter_offset = (iter_row % blockDim.y) * d2_stride;
in_data[d3 + iter_offset + (iter_row / blockDim.y + (d1 % 2) * blockDim.z) * d1_stride] =
in_vec[input_offset + iter_offset * seq_length +
(iter_row / blockDim.y) * matrix_stride];
}
__syncthreads();
iteration_stride = d1_stride * blockDim.z;
int iter_index = cnt * d1_stride + d1 * d2_stride + d3;
int output_offset = d0 * d0_stride * blockDim.z + d2 * (iteration_stride << 1);
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_id = iter * iteration_stride + iter_index;
out_vec[output_offset + iter_id] = in_data[iter_id];
}
#endif
}
// 3 * [B A S N] - > [B S C*H]
template <>
void launch_transform4d_0213<float>(float* out,
const float* in,
int batch_size,
int heads,
int seq_length,
int hidden_dim,
hipStream_t stream,
int trans_count)
{
hidden_dim >>= 2;
dim3 grid_dims(batch_size, heads * ((seq_length - 1) / 8 + 1), trans_count);
dim3 block_dims(hidden_dim / heads, 8);
hipLaunchKernelGGL(( transform4d_0213<float>)
, dim3(grid_dims), dim3(block_dims), 0, stream, out, in, heads, seq_length, hidden_dim, 1);
}
template <>
void launch_transform4d_0213<__half>(__half* out,
const __half* in,
int batch_size,
int heads,
int seq_length,
int hidden_dim,
hipStream_t stream,
int trans_count)
{
hidden_dim >>= 3;
if (hidden_dim > 128 || hidden_dim < 16) {
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 grid_dims(batch_size, trans_count, (seq_length * head_ext));
dim3 block_dims(hidden_dim / heads, (heads / head_ext));
hipLaunchKernelGGL(( transform4d_0213<__half>), dim3(grid_dims), dim3(block_dims), 0, stream,
out, in, heads, seq_length, hidden_dim, head_ext);
} else {
dim3 grid_dims(batch_size, seq_length / 2);
dim3 block_dims(hidden_dim / heads, heads, trans_count);
hipLaunchKernelGGL(( transform4d_0213_v2), dim3(grid_dims), dim3(block_dims), 0, stream,
out, in, heads, seq_length, hidden_dim);
}
}
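As a worked example of the dispatch in launch_bias_add_transform_0213<__half> above (all configuration values below are assumed): for a hidden size of 768 in half precision, hidden_dim >>= 3 leaves 96 float4 vectors per token, which lies inside the [16, 128] window, so the shared-memory bias_add_transform_0213_v2 path is taken; outside that window the generic kernel with the head_ext split is used instead.

// Worked example of the launch geometry (assumed config: hidden = 768, heads = 12,
// trans_count = 3, batch = 8, seq = 128). Illustration only -- it mirrors the dispatch
// logic above and is not part of the deleted file.
#include <hip/hip_runtime.h>

void bias_add_dispatch_example()
{
    int hidden_dim = 768 >> 3;  // 96 float4 vectors of __half per token
    int heads = 12, trans_count = 3, batch = 8, seq = 128;
    // 16 <= 96 <= 128, so the fused shared-memory kernel (v2) would be selected:
    dim3 block_dim(hidden_dim / heads, heads, trans_count);  // (8, 12, 3) -> 288 threads per block
    dim3 grid_dim(batch, seq / 2);                           // (8, 64) blocks
    (void)block_dim;
    (void)grid_dim;
}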