"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "f5412e5f5a804a14f80993310622b6088598412f"
Unverified Commit 9654931c authored by Sangkug Lym, committed by GitHub

Support vectorized local reduction for p2p-based ReduceScatter overlap (#1452)

* Support vectorized local reduction for p2p-based ReduceScatter overlap

  Signed-off-by: Sangkug Lym <slym@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  For more information, see https://pre-commit.ci

* cleanup

  Signed-off-by: Sangkug Lym <slym@nvidia.com>

---------

Signed-off-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent eb28c650
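Note on the pattern this patch introduces: the local-reduction kernels move from one scalar per thread to nvec scalars per thread, loaded and stored through TransformerEngine's VectorizedLoader/VectorizedStorer helpers. Below is a minimal sketch of that shape, assuming the helper semantics visible in the diff (each element index addresses one aligned group of nvec scalars); the kernel name, includes, and structure here are illustrative, not part of the patch:

#include <cuda_bf16.h>
#include "common/util/vectorized_pointwise.h"  // VectorizedLoader / VectorizedStorer

template <int nvec>
__global__ void reduce_sketch(nv_bfloat16 *in, nv_bfloat16 *out, int num_inputs, int input_size) {
  const int n_groups = input_size / nvec;  // aligned groups per input buffer
  transformer_engine::VectorizedLoader<nv_bfloat16, nvec, true> loader(in, input_size * num_inputs);
  transformer_engine::VectorizedStorer<nv_bfloat16, nvec, true> storer(out, input_size);
  const size_t tid = threadIdx.x + blockDim.x * blockIdx.x;
  if (tid >= n_groups) return;
  float acc[nvec] = {0.f};  // accumulate in fp32, as the patch does
  for (int input_id = 0; input_id < num_inputs; ++input_id) {
    loader.load(tid + n_groups * input_id, input_size * num_inputs);  // one wide load
#pragma unroll
    for (int i = 0; i < nvec; ++i) acc[i] += static_cast<float>(loader.separate()[i]);
  }
#pragma unroll
  for (int i = 0; i < nvec; ++i) storer.separate()[i] = static_cast<nv_bfloat16>(acc[i]);
  storer.store(tid, input_size);  // one wide store
}

The payoff is fewer, wider memory transactions: each loader.load() replaces nvec scalar loads, while the fp32 accumulator array preserves the per-element summation order of the original code.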
@@ -9,10 +9,9 @@
 #include <cuda_runtime.h>
 #if __CUDA_ARCH__ >= 800
-#include <cuda_bf16.h>
-#define half nv_bfloat16
+#define half_dtype nv_bfloat16
 #else
-#include <cuda_fp16.h>
+#define half_dtype half
 #endif
 #include <assert.h>
@@ -20,6 +19,7 @@
 #include <unistd.h>
 #include "common/util/system.h"
+#include "common/util/vectorized_pointwise.h"
 #include "userbuffers.h"
 #define MAX_THREADS 1024
@@ -116,11 +116,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
     }
     int4 sum = val[0];
-    half *s = reinterpret_cast<half *>(&sum);
+    half_dtype *s = reinterpret_cast<half_dtype *>(&sum);
 #pragma unroll
     for (int i = 1; i < RANKS; i++) {
-      half *x = reinterpret_cast<half *>(&val[i]);
+      half_dtype *x = reinterpret_cast<half_dtype *>(&val[i]);
 #pragma unroll
       for (int j = 0; j < 8; j++) s[j] += x[j];
     }
@@ -200,11 +200,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
     }
     int4 sum = val[0];
-    half *s = reinterpret_cast<half *>(&sum);
+    half_dtype *s = reinterpret_cast<half_dtype *>(&sum);
 #pragma unroll
     for (int i = 1; i < RANKS; i++) {
-      half *x = reinterpret_cast<half *>(&val[i]);
+      half_dtype *x = reinterpret_cast<half_dtype *>(&val[i]);
 #pragma unroll
       for (int j = 0; j < 8; j++) s[j] += x[j];
     }
@@ -311,11 +311,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
     }
     int4 sum = val[0];
-    half *s = reinterpret_cast<half *>(&sum);
+    half_dtype *s = reinterpret_cast<half_dtype *>(&sum);
 #pragma unroll
     for (int i = 1; i < RANKS; i++) {
-      half *x = reinterpret_cast<half *>(&val[i]);
+      half_dtype *x = reinterpret_cast<half_dtype *>(&val[i]);
 #pragma unroll
       for (int j = 0; j < 8; j++) s[j] += x[j];
     }
@@ -378,11 +378,11 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
     }
     int4 sum = val[0];
-    half *s = reinterpret_cast<half *>(&sum);
+    half_dtype *s = reinterpret_cast<half_dtype *>(&sum);
 #pragma unroll
     for (int i = 1; i < RANKS; i++) {
-      half *x = reinterpret_cast<half *>(&val[i]);
+      half_dtype *x = reinterpret_cast<half_dtype *>(&val[i]);
 #pragma unroll
       for (int j = 0; j < 8; j++) s[j] += x[j];
     }
@@ -780,7 +780,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
   int physgpu, targetgpu, *myptr;
   int *reduceidptr, reduce_id;
   int lastSM = 0;
-  half hscale = (half)*scale;
+  half_dtype hscale = (half_dtype)*scale;
   if (threadIdx.x < RANKS) {
     physgpu = myrank * gpustep + firstrank;
@@ -823,13 +823,13 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
   }
   int4 sum[2] = {{0, 0, 0, 0}, {0, 0, 0, 0}};
-  half *s = reinterpret_cast<half *>(&sum);
+  half_dtype *s = reinterpret_cast<half_dtype *>(&sum);
 #pragma unroll
   for (int i = 0; i < RANKS; i++) {
     fp8type *x = reinterpret_cast<fp8type *>(&val[i]);
 #pragma unroll
-    for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++) s[j] += hscale * (half)(x[j]);
+    for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++) s[j] += hscale * (half_dtype)(x[j]);
   }
   int hline = 2 * line;
   (reinterpret_cast<int4 *>(outbuf))[(hline / rowlines) * skiplines + (hline % rowlines)] =
@@ -855,7 +855,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
   int physgpu, targetgpu, *myptr;
   int *reduceidptr, reduce_id;
   int lastSM = 0;
-  half hscale = (half)*scale;
+  half_dtype hscale = (half_dtype)*scale;
   if (threadIdx.x < RANKS) {
     physgpu = myrank * gpustep + firstrank;
@@ -919,13 +919,14 @@ __global__ void __launch_bounds__(MAX_THREADS)
     }
     int4 sum[2] = {{0, 0, 0, 0}, {0, 0, 0, 0}};
-    half *s = reinterpret_cast<half *>(&sum);
+    half_dtype *s = reinterpret_cast<half_dtype *>(&sum);
 #pragma unroll
     for (int i = 0; i < RANKS; i++) {
       fp8type *x = reinterpret_cast<fp8type *>(&val[i]);
 #pragma unroll
-      for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++) s[j] += hscale * (half)(x[j]);
+      for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++)
+        s[j] += hscale * (half_dtype)(x[j]);
     }
     (reinterpret_cast<int4 *>(outbuf))[index1_out] = sum[0];
     (reinterpret_cast<int4 *>(outbuf))[index2_out] = sum[1];
@@ -988,11 +989,11 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
     }
     int4 sum = val[0];
-    half *s = reinterpret_cast<half *>(&sum);
+    half_dtype *s = reinterpret_cast<half_dtype *>(&sum);
 #pragma unroll
     for (int i = 1; i < RANKS; i++) {
-      half *x = reinterpret_cast<half *>(&val[i]);
+      half_dtype *x = reinterpret_cast<half_dtype *>(&val[i]);
 #pragma unroll
       for (int j = 0; j < 8; j++) s[j] += x[j];
     }
@@ -1078,11 +1079,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
     }
     int4 sum = val[0];
-    half *s = reinterpret_cast<half *>(&sum);
+    half_dtype *s = reinterpret_cast<half_dtype *>(&sum);
 #pragma unroll
     for (int i = 1; i < RANKS; i++) {
-      half *x = reinterpret_cast<half *>(&val[i]);
+      half_dtype *x = reinterpret_cast<half_dtype *>(&val[i]);
 #pragma unroll
       for (int j = 0; j < 8; j++) s[j] += x[j];
     }
@@ -1169,11 +1170,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
     }
     int4 sum = val[0];
-    half *s = reinterpret_cast<half *>(&sum);
+    half_dtype *s = reinterpret_cast<half_dtype *>(&sum);
 #pragma unroll
     for (int i = 1; i < RANKS; i++) {
-      half *x = reinterpret_cast<half *>(&val[i]);
+      half_dtype *x = reinterpret_cast<half_dtype *>(&val[i]);
 #pragma unroll
       for (int j = 0; j < 8; j++) s[j] += x[j];
     }
@@ -2597,30 +2598,57 @@ void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream
   reset_counters_kernel<<<grid, block, 0, stream>>>(atomic_ptr, num_chunks, allgather);
 }
 
-template <typename fp8type>
+template <typename fp8type, int nvec>
 __global__ void __launch_bounds__(MAX_THREADS / 4)
     reduce_fp8_in_bf16_out_cuda(void *inputs, void *output, const float *scale,
-                                const int num_inputs, const int input_size) {
-  const size_t tid = threadIdx.x + blockDim.x * blockIdx.x;
+                                const int num_inputs, const int input_size,
+                                const int num_aligned_elements_per_input,
+                                const int tot_input_size) {
   fp8type *inputs_fp8 = reinterpret_cast<fp8type *>(inputs);
-  float accum_buf = static_cast<float>(inputs_fp8[tid]) * (*scale);
+  half_dtype *output_half = reinterpret_cast<half_dtype *>(output);
+  transformer_engine::VectorizedLoader<fp8type, nvec, true> loader(inputs_fp8, tot_input_size);
+  transformer_engine::VectorizedStorer<half_dtype, nvec, true> storer(output_half, input_size);
+  const size_t tid = threadIdx.x + blockDim.x * blockIdx.x;
+  if (tid >= num_aligned_elements_per_input) {
+    return;
+  }
+  float accum_buf[nvec];
+  loader.load(tid, tot_input_size);
 #pragma unroll
-  for (int i = 1; i < num_inputs; i++) {
-    accum_buf += static_cast<float>(inputs_fp8[tid + input_size * i]) * (*scale);
+  for (int i = 0; i < nvec; ++i) {
+    accum_buf[i] = static_cast<float>(loader.separate()[i]) * (*scale);
   }
-  half *output_half = reinterpret_cast<half *>(output);
-  output_half[tid] = (half)accum_buf;
+  for (int input_id = 1; input_id < num_inputs; ++input_id) {
+    loader.load(tid + num_aligned_elements_per_input * input_id, tot_input_size);
+#pragma unroll
+    for (int i = 0; i < nvec; ++i) {
+      accum_buf[i] += static_cast<float>(loader.separate()[i]) * (*scale);
+    }
+  }
+#pragma unroll
+  for (int i = 0; i < nvec; ++i) {
+    storer.separate()[i] = static_cast<half_dtype>(accum_buf[i]);
+  }
+  storer.store(tid, input_size);
 }
 
 template <typename fp8type>
 void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_inputs,
                             int input_size, cudaStream_t stream) {
+  constexpr int nvec = 32;
+  assert(input_size % nvec == 0);
+  const int num_aligned_elements_per_input = input_size / nvec;
+  const int tot_input_size = input_size * num_inputs;
   size_t num_threads = MAX_THREADS / 4;
-  size_t num_blocks = (input_size + num_threads - 1) / num_threads;
+  size_t num_blocks = (num_aligned_elements_per_input + num_threads - 1) / num_threads;
   dim3 block(num_threads);
   dim3 grid(num_blocks);
-  reduce_fp8_in_bf16_out_cuda<fp8type>
-      <<<grid, block, 0, stream>>>(inputs, output, scale, num_inputs, input_size);
+  reduce_fp8_in_bf16_out_cuda<fp8type, nvec>
+      <<<grid, block, 0, stream>>>(inputs, output, scale, num_inputs, input_size,
+                                   num_aligned_elements_per_input, tot_input_size);
 }
 
 template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(void *inputs, void *output, float *scale,
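A worked example of the launch arithmetic above (the concrete numbers are illustrative, not from the patch): with nvec = 32, each thread now reduces 32 scalars per input, so the grid shrinks 32x relative to the old one-element-per-thread launch.

// Suppose input_size = 1048576 fp8 elements per input, with MAX_THREADS = 1024 as defined above:
//   num_aligned_elements_per_input = 1048576 / 32 = 32768 vector groups
//   num_threads = MAX_THREADS / 4 = 256
//   num_blocks  = ceil(32768 / 256) = 128
// The pre-patch launcher would have computed ceil(1048576 / 256) = 4096 blocks,
// with each thread issuing num_inputs scalar loads instead of num_inputs wide vectorized loads.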
@@ -2630,23 +2658,50 @@ template void reduce_fp8_in_bf16_out<__nv_fp8_e5m2>(void *inputs, void *output,
                                                     int num_inputs, int input_size,
                                                     cudaStream_t stream);
 
+template <int nvec>
 __global__ void __launch_bounds__(MAX_THREADS / 4)
-    reduce_bf16_cuda(void *inputs, void *output, const int num_inputs, const int input_size) {
+    reduce_bf16_cuda(void *inputs, void *output, const int num_inputs, const int input_size,
+                     const int num_aligned_elements_per_input, const int tot_input_size) {
+  half_dtype *inputs_half = reinterpret_cast<half_dtype *>(inputs);
+  half_dtype *output_half = reinterpret_cast<half_dtype *>(output);
+  transformer_engine::VectorizedLoader<half_dtype, nvec, true> loader(inputs_half, tot_input_size);
+  transformer_engine::VectorizedStorer<half_dtype, nvec, true> storer(output_half, input_size);
   const size_t tid = threadIdx.x + blockDim.x * blockIdx.x;
-  half *inputs_half = reinterpret_cast<half *>(inputs);
-  float accum_buf = static_cast<float>(inputs_half[tid]);
+  if (tid >= num_aligned_elements_per_input) {
+    return;
+  }
+  float accum_buf[nvec];
+  loader.load(tid, tot_input_size);
+#pragma unroll
+  for (int i = 0; i < nvec; ++i) {
+    accum_buf[i] = static_cast<float>(loader.separate()[i]);
+  }
+  for (int input_id = 1; input_id < num_inputs; ++input_id) {
+    loader.load(tid + num_aligned_elements_per_input * input_id, tot_input_size);
+#pragma unroll
+    for (int i = 0; i < nvec; ++i) {
+      accum_buf[i] += static_cast<float>(loader.separate()[i]);
+    }
+  }
 #pragma unroll
-  for (int i = 1; i < num_inputs; i++) {
-    accum_buf += static_cast<float>(inputs_half[tid + input_size * i]);
+  for (int i = 0; i < nvec; ++i) {
+    storer.separate()[i] = static_cast<half_dtype>(accum_buf[i]);
   }
-  half *output_half = reinterpret_cast<half *>(output);
-  output_half[tid] = (half)accum_buf;
+  storer.store(tid, input_size);
 }
 
 void reduce_bf16(void *inputs, void *output, int num_inputs, int input_size, cudaStream_t stream) {
+  constexpr int nvec = 32;
+  assert(input_size % nvec == 0);
+  const int num_aligned_elements_per_input = input_size / nvec;
+  const int tot_input_size = input_size * num_inputs;
   size_t num_threads = MAX_THREADS / 4;
-  size_t num_blocks = (input_size + num_threads - 1) / num_threads;
+  size_t num_blocks = (num_aligned_elements_per_input + num_threads - 1) / num_threads;
   dim3 block(num_threads);
   dim3 grid(num_blocks);
-  reduce_bf16_cuda<<<grid, block, 0, stream>>>(inputs, output, num_inputs, input_size);
+  reduce_bf16_cuda<nvec><<<grid, block, 0, stream>>>(
+      inputs, output, num_inputs, input_size, num_aligned_elements_per_input, tot_input_size);
 }
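A hypothetical host-side call into the new reduce_bf16 entry point (buffer names and sizes are illustrative, and the declaring header is assumed), showing the constraint the added assert enforces: input_size must be a multiple of nvec = 32.

#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include "userbuffers.h"  // assumed to declare reduce_bf16

int main() {
  // Reduce num_inputs contiguous bf16 chunks of input_size elements each into out.
  const int num_inputs = 8;
  const int input_size = 1 << 20;  // 1048576: satisfies input_size % 32 == 0
  void *chunks = nullptr, *out = nullptr;
  cudaMalloc(&chunks, sizeof(nv_bfloat16) * static_cast<size_t>(input_size) * num_inputs);
  cudaMalloc(&out, sizeof(nv_bfloat16) * input_size);
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  reduce_bf16(chunks, out, num_inputs, input_size, stream);
  cudaStreamSynchronize(stream);
  return 0;
}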