Unverified Commit 9654931c authored by Sangkug Lym's avatar Sangkug Lym Committed by GitHub
Browse files

Support vectorized local reduction for p2p-based ReduceScatter overlap (#1452)



* Support vectorized local reduction for p2p-based ReduceScatter overlap
Signed-off-by: default avatarSangkug Lym <slym@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* cleanup
Signed-off-by: default avatarSangkug Lym <slym@nvidia.com>

---------
Signed-off-by: default avatarSangkug Lym <slym@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent eb28c650
......@@ -9,10 +9,9 @@
#include <cuda_runtime.h>
#if __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#define half nv_bfloat16
#define half_dtype nv_bfloat16
#else
#include <cuda_fp16.h>
#define half_dtype half
#endif
#include <assert.h>
......@@ -20,6 +19,7 @@
#include <unistd.h>
#include "common/util/system.h"
#include "common/util/vectorized_pointwise.h"
#include "userbuffers.h"
#define MAX_THREADS 1024
......@@ -116,11 +116,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
}
int4 sum = val[0];
half *s = reinterpret_cast<half *>(&sum);
half_dtype *s = reinterpret_cast<half_dtype *>(&sum);
#pragma unroll
for (int i = 1; i < RANKS; i++) {
half *x = reinterpret_cast<half *>(&val[i]);
half_dtype *x = reinterpret_cast<half_dtype *>(&val[i]);
#pragma unroll
for (int j = 0; j < 8; j++) s[j] += x[j];
}
......@@ -200,11 +200,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
}
int4 sum = val[0];
half *s = reinterpret_cast<half *>(&sum);
half_dtype *s = reinterpret_cast<half_dtype *>(&sum);
#pragma unroll
for (int i = 1; i < RANKS; i++) {
half *x = reinterpret_cast<half *>(&val[i]);
half_dtype *x = reinterpret_cast<half_dtype *>(&val[i]);
#pragma unroll
for (int j = 0; j < 8; j++) s[j] += x[j];
}
......@@ -311,11 +311,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
}
int4 sum = val[0];
half *s = reinterpret_cast<half *>(&sum);
half_dtype *s = reinterpret_cast<half_dtype *>(&sum);
#pragma unroll
for (int i = 1; i < RANKS; i++) {
half *x = reinterpret_cast<half *>(&val[i]);
half_dtype *x = reinterpret_cast<half_dtype *>(&val[i]);
#pragma unroll
for (int j = 0; j < 8; j++) s[j] += x[j];
}
......@@ -378,11 +378,11 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
}
int4 sum = val[0];
half *s = reinterpret_cast<half *>(&sum);
half_dtype *s = reinterpret_cast<half_dtype *>(&sum);
#pragma unroll
for (int i = 1; i < RANKS; i++) {
half *x = reinterpret_cast<half *>(&val[i]);
half_dtype *x = reinterpret_cast<half_dtype *>(&val[i]);
#pragma unroll
for (int j = 0; j < 8; j++) s[j] += x[j];
}
......@@ -780,7 +780,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
int physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
int lastSM = 0;
half hscale = (half)*scale;
half_dtype hscale = (half_dtype)*scale;
if (threadIdx.x < RANKS) {
physgpu = myrank * gpustep + firstrank;
......@@ -823,13 +823,13 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
}
int4 sum[2] = {{0, 0, 0, 0}, {0, 0, 0, 0}};
half *s = reinterpret_cast<half *>(&sum);
half_dtype *s = reinterpret_cast<half_dtype *>(&sum);
#pragma unroll
for (int i = 0; i < RANKS; i++) {
fp8type *x = reinterpret_cast<fp8type *>(&val[i]);
#pragma unroll
for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++) s[j] += hscale * (half)(x[j]);
for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++) s[j] += hscale * (half_dtype)(x[j]);
}
int hline = 2 * line;
(reinterpret_cast<int4 *>(outbuf))[(hline / rowlines) * skiplines + (hline % rowlines)] =
......@@ -855,7 +855,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
int physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
int lastSM = 0;
half hscale = (half)*scale;
half_dtype hscale = (half_dtype)*scale;
if (threadIdx.x < RANKS) {
physgpu = myrank * gpustep + firstrank;
......@@ -919,13 +919,14 @@ __global__ void __launch_bounds__(MAX_THREADS)
}
int4 sum[2] = {{0, 0, 0, 0}, {0, 0, 0, 0}};
half *s = reinterpret_cast<half *>(&sum);
half_dtype *s = reinterpret_cast<half_dtype *>(&sum);
#pragma unroll
for (int i = 0; i < RANKS; i++) {
fp8type *x = reinterpret_cast<fp8type *>(&val[i]);
#pragma unroll
for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++) s[j] += hscale * (half)(x[j]);
for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++)
s[j] += hscale * (half_dtype)(x[j]);
}
(reinterpret_cast<int4 *>(outbuf))[index1_out] = sum[0];
(reinterpret_cast<int4 *>(outbuf))[index2_out] = sum[1];
......@@ -988,11 +989,11 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
}
int4 sum = val[0];
half *s = reinterpret_cast<half *>(&sum);
half_dtype *s = reinterpret_cast<half_dtype *>(&sum);
#pragma unroll
for (int i = 1; i < RANKS; i++) {
half *x = reinterpret_cast<half *>(&val[i]);
half_dtype *x = reinterpret_cast<half_dtype *>(&val[i]);
#pragma unroll
for (int j = 0; j < 8; j++) s[j] += x[j];
}
......@@ -1078,11 +1079,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
}
int4 sum = val[0];
half *s = reinterpret_cast<half *>(&sum);
half_dtype *s = reinterpret_cast<half_dtype *>(&sum);
#pragma unroll
for (int i = 1; i < RANKS; i++) {
half *x = reinterpret_cast<half *>(&val[i]);
half_dtype *x = reinterpret_cast<half_dtype *>(&val[i]);
#pragma unroll
for (int j = 0; j < 8; j++) s[j] += x[j];
}
......@@ -1169,11 +1170,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
}
int4 sum = val[0];
half *s = reinterpret_cast<half *>(&sum);
half_dtype *s = reinterpret_cast<half_dtype *>(&sum);
#pragma unroll
for (int i = 1; i < RANKS; i++) {
half *x = reinterpret_cast<half *>(&val[i]);
half_dtype *x = reinterpret_cast<half_dtype *>(&val[i]);
#pragma unroll
for (int j = 0; j < 8; j++) s[j] += x[j];
}
......@@ -2597,30 +2598,57 @@ void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream
reset_counters_kernel<<<grid, block, 0, stream>>>(atomic_ptr, num_chunks, allgather);
}
template <typename fp8type>
template <typename fp8type, int nvec>
__global__ void __launch_bounds__(MAX_THREADS / 4)
reduce_fp8_in_bf16_out_cuda(void *inputs, void *output, const float *scale,
const int num_inputs, const int input_size) {
const size_t tid = threadIdx.x + blockDim.x * blockIdx.x;
const int num_inputs, const int input_size,
const int num_aligned_elements_per_input,
const int tot_input_size) {
fp8type *inputs_fp8 = reinterpret_cast<fp8type *>(inputs);
float accum_buf = static_cast<float>(inputs_fp8[tid]) * (*scale);
half_dtype *output_half = reinterpret_cast<half_dtype *>(output);
transformer_engine::VectorizedLoader<fp8type, nvec, true> loader(inputs_fp8, tot_input_size);
transformer_engine::VectorizedStorer<half_dtype, nvec, true> storer(output_half, input_size);
const size_t tid = threadIdx.x + blockDim.x * blockIdx.x;
if (tid >= num_aligned_elements_per_input) {
return;
}
float accum_buf[nvec];
loader.load(tid, tot_input_size);
#pragma unroll
for (int i = 1; i < num_inputs; i++) {
accum_buf += static_cast<float>(inputs_fp8[tid + input_size * i]) * (*scale);
for (int i = 0; i < nvec; ++i) {
accum_buf[i] = static_cast<float>(loader.separate()[i]) * (*scale);
}
half *output_half = reinterpret_cast<half *>(output);
output_half[tid] = (half)accum_buf;
for (int input_id = 1; input_id < num_inputs; ++input_id) {
loader.load(tid + num_aligned_elements_per_input * input_id, tot_input_size);
#pragma unroll
for (int i = 0; i < nvec; ++i) {
accum_buf[i] += static_cast<float>(loader.separate()[i]) * (*scale);
}
}
#pragma unroll
for (int i = 0; i < nvec; ++i) {
storer.separate()[i] = static_cast<half_dtype>(accum_buf[i]);
}
storer.store(tid, input_size);
}
template <typename fp8type>
void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_inputs,
int input_size, cudaStream_t stream) {
constexpr int nvec = 32;
assert(input_size % nvec == 0);
const int num_aligned_elements_per_input = input_size / nvec;
const int tot_input_size = input_size * num_inputs;
size_t num_threads = MAX_THREADS / 4;
size_t num_blocks = (input_size + num_threads - 1) / num_threads;
size_t num_blocks = (num_aligned_elements_per_input + num_threads - 1) / num_threads;
dim3 block(num_threads);
dim3 grid(num_blocks);
reduce_fp8_in_bf16_out_cuda<fp8type>
<<<grid, block, 0, stream>>>(inputs, output, scale, num_inputs, input_size);
reduce_fp8_in_bf16_out_cuda<fp8type, nvec>
<<<grid, block, 0, stream>>>(inputs, output, scale, num_inputs, input_size,
num_aligned_elements_per_input, tot_input_size);
}
template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(void *inputs, void *output, float *scale,
......@@ -2630,23 +2658,50 @@ template void reduce_fp8_in_bf16_out<__nv_fp8_e5m2>(void *inputs, void *output,
int num_inputs, int input_size,
cudaStream_t stream);
template <int nvec>
__global__ void __launch_bounds__(MAX_THREADS / 4)
reduce_bf16_cuda(void *inputs, void *output, const int num_inputs, const int input_size) {
reduce_bf16_cuda(void *inputs, void *output, const int num_inputs, const int input_size,
const int num_aligned_elements_per_input, const int tot_input_size) {
half_dtype *inputs_half = reinterpret_cast<half_dtype *>(inputs);
half_dtype *output_half = reinterpret_cast<half_dtype *>(output);
transformer_engine::VectorizedLoader<half_dtype, nvec, true> loader(inputs_half, tot_input_size);
transformer_engine::VectorizedStorer<half_dtype, nvec, true> storer(output_half, input_size);
const size_t tid = threadIdx.x + blockDim.x * blockIdx.x;
half *inputs_half = reinterpret_cast<half *>(inputs);
float accum_buf = static_cast<float>(inputs_half[tid]);
if (tid >= num_aligned_elements_per_input) {
return;
}
float accum_buf[nvec];
loader.load(tid, tot_input_size);
#pragma unroll
for (int i = 0; i < nvec; ++i) {
accum_buf[i] = static_cast<float>(loader.separate()[i]);
}
for (int input_id = 1; input_id < num_inputs; ++input_id) {
loader.load(tid + num_aligned_elements_per_input * input_id, tot_input_size);
#pragma unroll
for (int i = 0; i < nvec; ++i) {
accum_buf[i] += static_cast<float>(loader.separate()[i]);
}
}
#pragma unroll
for (int i = 1; i < num_inputs; i++) {
accum_buf += static_cast<float>(inputs_half[tid + input_size * i]);
for (int i = 0; i < nvec; ++i) {
storer.separate()[i] = static_cast<half_dtype>(accum_buf[i]);
}
half *output_half = reinterpret_cast<half *>(output);
output_half[tid] = (half)accum_buf;
storer.store(tid, input_size);
}
void reduce_bf16(void *inputs, void *output, int num_inputs, int input_size, cudaStream_t stream) {
constexpr int nvec = 32;
assert(input_size % nvec == 0);
const int num_aligned_elements_per_input = input_size / nvec;
const int tot_input_size = input_size * num_inputs;
size_t num_threads = MAX_THREADS / 4;
size_t num_blocks = (input_size + num_threads - 1) / num_threads;
size_t num_blocks = (num_aligned_elements_per_input + num_threads - 1) / num_threads;
dim3 block(num_threads);
dim3 grid(num_blocks);
reduce_bf16_cuda<<<grid, block, 0, stream>>>(inputs, output, num_inputs, input_size);
reduce_bf16_cuda<nvec><<<grid, block, 0, stream>>>(
inputs, output, num_inputs, input_size, num_aligned_elements_per_input, tot_input_size);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment