Unverified Commit 771d4b83 authored by shenggan, committed by GitHub

use template in layernorm kernel & add unittest for fastnn layernorm (#25)

parent ad1bbc52
@@ -58,78 +58,9 @@ __inline__ __device__ void WelfordWarpAllReduce(float thread_mean, float thread_
     *count = __shfl_sync(0xffffffff, *count, 0, 32);
 }
 
-__global__ void fastfold_layernorm_fp32(float* input, float* output, float* gamma, float* beta,
-                                        float* mean, float* invvar, int rows, int cols,
-                                        double epsilon) {
-    int threadidx_x = threadIdx.x / 32;
-    int threadidx_y = threadIdx.x % 32;
-    int row_offset = blockIdx.x * 4 + threadidx_x;
-    int cols_per_thread = (cols + 31) / 32;
-    int cols_this_thread = cols_per_thread;
-
-    int last_y = (cols / cols_per_thread);
-
-    if (threadidx_y == last_y) {
-        cols_this_thread = cols - cols_per_thread * last_y;
-    }
-    else if (threadidx_y > last_y) {
-        cols_this_thread = 0;
-    }
-
-    int lane_id = threadidx_y;
-
-    if (row_offset < rows) {
-        float buf[32];
-
-        float thread_mean = 0.f;
-        float thread_m2 = 0.f;
-        float thread_count = 0.f;
-
-        float warp_mean;
-        float warp_m2;
-        float warp_count;
-
-        float* row_input = input + row_offset * cols;
-        float* row_output = output + row_offset * cols;
-
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; i++) {
-            buf[i] = row_input[lane_id * cols_per_thread + i];
-        }
-
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; i++) {
-            WelfordOnline(buf[i], &thread_mean, &thread_m2, &thread_count);
-        }
-
-        WelfordWarpAllReduce(thread_mean, thread_m2, thread_count, &warp_mean, &warp_m2, &warp_count);
-
-        float row_mean = warp_mean;
-        float row_variance = max(warp_m2 / warp_count, 0.f);
-        float row_inv_var = rsqrt(row_variance + epsilon);
-
-        if (lane_id == 0) {
-            mean[row_offset] = row_mean;
-            invvar[row_offset] = row_inv_var;
-        }
-
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; ++i) {
-            buf[i] = (buf[i] - row_mean) * row_inv_var;
-        }
-
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; ++i) {
-            row_output[lane_id * cols_per_thread + i] =
-                buf[i] * gamma[lane_id * cols_per_thread + i] + beta[lane_id * cols_per_thread + i];
-        }
-    }
-}
-
-__global__ void fastfold_layernorm_bfp16(at::BFloat16* input, at::BFloat16* output,
-                                         at::BFloat16* gamma, at::BFloat16* beta, float* mean,
-                                         float* invvar, int rows, int cols, double epsilon) {
+template <typename T>
+__global__ void fastfold_layernorm(T* input, T* output, T* gamma, T* beta, float* mean,
+                                   float* invvar, int rows, int cols, double epsilon) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     int row_offset = blockIdx.x * 4 + threadidx_x;
@@ -140,15 +71,13 @@ __global__ void fastfold_layernorm_bfp16(at::BFloat16* input, at::BFloat16* outp
     if (threadidx_y == last_y) {
         cols_this_thread = cols - cols_per_thread * last_y;
-    }
-    else if (threadidx_y > last_y) {
+    } else if (threadidx_y > last_y) {
         cols_this_thread = 0;
     }
 
     int lane_id = threadidx_y;
 
     if (row_offset < rows) {
         float buf[32];
 
         float thread_mean = 0.f;
@@ -159,20 +88,21 @@ __global__ void fastfold_layernorm_bfp16(at::BFloat16* input, at::BFloat16* outp
         float warp_m2;
         float warp_count;
 
-        at::BFloat16* row_input = input + row_offset * cols;
-        at::BFloat16* row_output = output + row_offset * cols;
+        T* row_input = input + row_offset * cols;
+        T* row_output = output + row_offset * cols;
 
 #pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
             buf[i] = static_cast<float>(row_input[lane_id * cols_per_thread + i]);
         }
 
 #pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
             WelfordOnline(buf[i], &thread_mean, &thread_m2, &thread_count);
         }
 
-        WelfordWarpAllReduce(thread_mean, thread_m2, thread_count, &warp_mean, &warp_m2, &warp_count);
+        WelfordWarpAllReduce(thread_mean, thread_m2, thread_count, &warp_mean, &warp_m2,
+                             &warp_count);
 
         float row_mean = warp_mean;
         float row_variance = max(warp_m2 / warp_count, 0.f);
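For context on the reduction above: each lane accumulates a running (mean, m2, count) over its own slice of the row via WelfordOnline, and WelfordWarpAllReduce merges the 32 partial statistics across the warp with shuffle instructions before lane 0's result is broadcast back. A minimal Python sketch of the two updates, assuming the kernel follows the standard Welford/Chan formulas (the function names here are illustrative, not from this repo):

def welford_online(x, mean, m2, count):
    # Fold one new sample into a running (mean, m2, count) accumulator.
    count += 1.0
    delta = x - mean
    mean += delta / count
    m2 += delta * (x - mean)
    return mean, m2, count

def welford_merge(mean_a, m2_a, count_a, mean_b, m2_b, count_b):
    # Merge two partial accumulators; the warp reduce applies this pairwise
    # across lanes, then broadcasts lane 0's result with __shfl_sync.
    count = count_a + count_b
    if count == 0.0:
        return 0.0, 0.0, 0.0
    delta = mean_b - mean_a
    mean = mean_a + delta * count_b / count
    m2 = m2_a + m2_b + delta * delta * count_a * count_b / count
    return mean, m2, count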
@@ -183,15 +113,15 @@ __global__ void fastfold_layernorm_bfp16(at::BFloat16* input, at::BFloat16* outp
             invvar[row_offset] = row_inv_var;
         }
 
 #pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
             buf[i] = (buf[i] - row_mean) * row_inv_var;
         }
 
 #pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
             row_output[lane_id * cols_per_thread + i] =
-                static_cast<at::BFloat16>(buf[i]) * gamma[lane_id * cols_per_thread + i] +
+                static_cast<T>(buf[i]) * gamma[lane_id * cols_per_thread + i] +
                 beta[lane_id * cols_per_thread + i];
         }
     }
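One templated kernel now serves every dtype, and the thread mapping is unchanged: each 128-thread block carries four warps, each warp owns one row (row_offset = blockIdx.x * 4 + threadidx_x), and each of the 32 lanes stages cols_per_thread = ceil(cols / 32) contiguous elements in its float buf[32], which caps the supported width at 32 * 32 = 1024 columns. The boundary arithmetic is easy to sanity-check in Python for the ragged cols = 129 case the new unit test adds:

cols = 129                               # the non-multiple-of-32 width in the new test
cols_per_thread = (cols + 31) // 32      # ceil(129 / 32) = 5
last_y = cols // cols_per_thread         # 129 // 5 = 25, the partially filled lane

counts = []
for lane in range(32):
    if lane == last_y:
        counts.append(cols - cols_per_thread * last_y)   # 129 - 125 = 4
    elif lane > last_y:
        counts.append(0)                                  # lanes 26..31 sit idle
    else:
        counts.append(cols_per_thread)                    # lanes 0..24 take 5 each

assert sum(counts) == cols               # every column is covered exactly once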
@@ -204,12 +134,17 @@ void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar, a
     dim3 block(128);
 
     if (output->dtype() == torch::kFloat32) {
-        fastfold_layernorm_fp32<<<grid, block>>>(
+        fastfold_layernorm<float><<<grid, block>>>(
             (float*)input->data_ptr(), (float*)output->data_ptr(), (float*)gamma->data_ptr(),
             (float*)beta->data_ptr(), (float*)mean->data_ptr(), (float*)invvar->data_ptr(), rows,
             cols, epsilon);
-    } else {
-        fastfold_layernorm_bfp16<<<grid, block>>>(
+    } else if (output->dtype() == torch::kFloat16) {
+        fastfold_layernorm<at::Half><<<grid, block>>>(
+            (at::Half*)input->data_ptr(), (at::Half*)output->data_ptr(),
+            (at::Half*)gamma->data_ptr(), (at::Half*)beta->data_ptr(), (float*)mean->data_ptr(),
+            (float*)invvar->data_ptr(), rows, cols, epsilon);
+    } else if (output->dtype() == torch::kBFloat16) {
+        fastfold_layernorm<at::BFloat16><<<grid, block>>>(
             (at::BFloat16*)input->data_ptr(), (at::BFloat16*)output->data_ptr(),
             (at::BFloat16*)gamma->data_ptr(), (at::BFloat16*)beta->data_ptr(),
             (float*)mean->data_ptr(), (float*)invvar->data_ptr(), rows, cols, epsilon);
...
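For readers checking the numerics against the new unit test below: a rough PyTorch sketch of what the kernel computes for a single row. Two details visible in the diff are worth noting: statistics are always accumulated in fp32 (the float buf staging buffer), and the normalized value is cast back to T before the gamma/beta affine, so that multiply-add runs in the input precision. This is an illustrative reference, not code from the repo:

import torch

def layernorm_row_reference(row, gamma, beta, eps):
    x = row.float()                              # kernel stages inputs as fp32
    mean = x.mean()
    var = x.var(unbiased=False).clamp_min(0.0)   # mirrors max(warp_m2 / warp_count, 0.f)
    inv_var = torch.rsqrt(var + eps)
    normed = (x - mean) * inv_var
    # Cast back to the input dtype before the affine, as the kernel does.
    out = normed.to(row.dtype) * gamma + beta
    return out, mean, inv_var                    # mean/invvar stay fp32, saved for backward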
import torch

from fastfold.model.fastnn.kernel import LayerNorm as FastLayerNorm


def test_layernorm():
    # [batch, dim]
    test_shape = [[64, 64], [64, 128], [64, 129], [64, 1024]]
    test_dtype = [torch.float32, torch.float16, torch.bfloat16]
    test_device = torch.device("cuda")
    tolerance_eps = {torch.float32: 10e-5, torch.float16: 10e-2, torch.bfloat16: 10e-2}

    for shape in test_shape:
        for dtype in test_dtype:
            sample_input = torch.rand(shape).to(device=test_device,
                                                dtype=dtype).requires_grad_(False)
            dim_ = sample_input.size()[-1]
            torch_module = torch.nn.LayerNorm(normalized_shape=dim_).to(device=test_device,
                                                                        dtype=dtype)
            fastnn_module = FastLayerNorm(normalized_shape=dim_).to(device=test_device,
                                                                    dtype=dtype)

            # Forward
            torch_out = torch_module(sample_input)
            fastnn_out = fastnn_module(sample_input)

            forward_error = torch.max(torch.abs(torch_out - fastnn_out)).cpu().item()
            assert forward_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"

            # Backward
            out_grad = torch.rand_like(torch_out).requires_grad_(False)
            torch_out.backward(out_grad)
            fastnn_out.backward(out_grad)

            backward_weight_error = torch.max(
                torch.abs(torch_module.weight.grad - fastnn_module.weight.grad)).cpu().item()
            assert backward_weight_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"

            backward_bias_error = torch.max(
                torch.abs(torch_module.bias.grad - fastnn_module.bias.grad)).cpu().item()
            assert backward_bias_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"
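The test requires a CUDA device and the compiled fastnn extension. Assuming the file lives under the repository's test directory (the path below is illustrative, not taken from this commit), it can be run through pytest or invoked directly:

# pytest -q tests/test_fastnn_layernorm.py      (illustrative path)
if __name__ == "__main__":
    test_layernorm()
    print("fastnn LayerNorm matches torch.nn.LayerNorm on all tested shapes and dtypes")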