Commit 41199996 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.12.0' into v0.12.0-dev

parents 31021d81 4fd9d6a8
...@@ -67,8 +67,9 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a, ...@@ -67,8 +67,9 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
std::optional<torch::Tensor> const& bias); std::optional<torch::Tensor> const& bias);
#endif #endif
#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \ #if (defined(ENABLE_CUTLASS_MOE_SM90) && ENABLE_CUTLASS_MOE_SM90) || \
defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100 (defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100) || \
(defined(ENABLE_CUTLASS_MOE_SM120) && ENABLE_CUTLASS_MOE_SM120)
void get_cutlass_moe_mm_data_caller( void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
...@@ -253,7 +254,7 @@ void cutlass_moe_mm( ...@@ -253,7 +254,7 @@ void cutlass_moe_mm(
bool per_act_token, bool per_out_ch) { bool per_act_token, bool per_out_ch) {
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100 #if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
if (version_num >= 100) { if (version_num >= 100 && version_num < 110) {
cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales, cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides, expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch); c_strides, per_act_token, per_out_ch);
...@@ -261,7 +262,7 @@ void cutlass_moe_mm( ...@@ -261,7 +262,7 @@ void cutlass_moe_mm(
} }
#endif #endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 #if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
if (version_num >= 90) { if (version_num >= 90 && version_num < 100) {
cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides, expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch); c_strides, per_act_token, per_out_ch);
...@@ -283,8 +284,9 @@ void get_cutlass_moe_mm_data( ...@@ -283,8 +284,9 @@ void get_cutlass_moe_mm_data(
// This function currently gets compiled only if we have a valid cutlass moe // This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for. // mm to run it for.
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ #if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, input_permutation, problem_sizes2, input_permutation,
output_permutation, num_experts, n, k, output_permutation, num_experts, n, k,
...@@ -295,7 +297,7 @@ void get_cutlass_moe_mm_data( ...@@ -295,7 +297,7 @@ void get_cutlass_moe_mm_data(
false, false,
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for " "No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
"CUDA device capability: ", "CUDA device capability: ",
version_num, ". Required capability: 90 or 100"); version_num, ". Required capability: 90, 100, or 120");
} }
void get_cutlass_moe_mm_problem_sizes( void get_cutlass_moe_mm_problem_sizes(
...@@ -303,8 +305,9 @@ void get_cutlass_moe_mm_problem_sizes( ...@@ -303,8 +305,9 @@ void get_cutlass_moe_mm_problem_sizes(
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) { const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ #if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1, get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1,
problem_sizes2, num_experts, n, k, problem_sizes2, num_experts, n, k,
blockscale_offsets); blockscale_offsets);
...@@ -314,7 +317,7 @@ void get_cutlass_moe_mm_problem_sizes( ...@@ -314,7 +317,7 @@ void get_cutlass_moe_mm_problem_sizes(
false, false,
"No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm " "No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm "
"kernel for CUDA device capability: ", "kernel for CUDA device capability: ",
version_num, ". Required capability: 90 or 100"); version_num, ". Required capability: 90, 100, or 120");
} }
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
...@@ -327,8 +330,9 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, ...@@ -327,8 +330,9 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
// This function currently gets compiled only if we have a valid cutlass moe // This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for. // mm to run it for.
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ #if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1, get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1,
problem_sizes2, expert_num_tokens, problem_sizes2, expert_num_tokens,
num_local_experts, padded_m, n, k); num_local_experts, padded_m, n, k);
...@@ -338,7 +342,7 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, ...@@ -338,7 +342,7 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
false, false,
"No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel " "No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
"for CUDA device capability: ", "for CUDA device capability: ",
version_num, ". Required capability: 90 or 100"); version_num, ". Required capability: 90, 100, or 120");
} }
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
#include <hip/hip_bfloat16.h> #include <hip/hip_bfloat16.h>
#include "../../../attention/attention_dtypes.h" #include "../../../../attention/attention_dtypes.h"
namespace vllm { namespace vllm {
#ifdef USE_ROCM #ifdef USE_ROCM
......
#include "common.cuh" #include "common.cuh"
#include "dispatch_utils.h" #include "dispatch_utils.h"
#include "../../cub_helpers.h" #include "cub_helpers.h"
#include "../vectorization_utils.cuh" #include "quantization/vectorization_utils.cuh"
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/Exceptions.h> #include <ATen/cuda/Exceptions.h>
......
#pragma once #pragma once
#include "../../../attention/attention_dtypes.h" #include "../../../../attention/attention_dtypes.h"
#include <assert.h> #include <assert.h>
#include <float.h> #include <float.h>
#include <stdint.h> #include <stdint.h>
......
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include "../per_token_group_quant_8bit.h" #include "quantization/w8a8/per_token_group_quant_8bit.h"
#include <cmath> #include <cmath>
...@@ -8,9 +8,9 @@ ...@@ -8,9 +8,9 @@
#include <torch/all.h> #include <torch/all.h>
#include "../vectorization.cuh" #include "quantization/vectorization.cuh"
#include "../vectorization_utils.cuh" #include "quantization/vectorization_utils.cuh"
#include "../../dispatch_utils.h" #include "dispatch_utils.h"
__device__ __forceinline__ float GroupReduceMax(float val) { __device__ __forceinline__ float GroupReduceMax(float val) {
unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff; unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff;
...@@ -212,4 +212,4 @@ void per_token_group_quant_fp8(const torch::Tensor& input, ...@@ -212,4 +212,4 @@ void per_token_group_quant_fp8(const torch::Tensor& input,
double fp8_max, bool scale_ue8m0) { double fp8_max, bool scale_ue8m0) {
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
fp8_min, fp8_max, scale_ue8m0); fp8_min, fp8_max, scale_ue8m0);
} }
\ No newline at end of file
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include "quantization/w8a8/per_token_group_quant_8bit.h"
void per_token_group_quant_int8(const torch::Tensor& input,
torch::Tensor& output_q,
torch::Tensor& output_s, int64_t group_size,
double eps, double int8_min, double int8_max) {
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
int8_min, int8_max);
}
\ No newline at end of file
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/all.h> #include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#ifndef USE_ROCM
#include "../per_token_group_quant_8bit.h"
#endif
#include <cmath> #include <cmath>
#include "../../cub_helpers.h" #include "dispatch_utils.h"
#include "../../dispatch_utils.h" #include "quantization/vectorization_utils.cuh"
#include "../vectorization_utils.cuh" #include "cub_helpers.h"
static inline __device__ int8_t float_to_int8_rn(float x) { static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM #ifdef USE_ROCM
...@@ -25,7 +22,6 @@ static inline __device__ int8_t float_to_int8_rn(float x) { ...@@ -25,7 +22,6 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
float dst = std::nearbyint(x); float dst = std::nearbyint(x);
// saturate // saturate
// See https://github.com/pytorch/pytorch/issues/127666 // See https://github.com/pytorch/pytorch/issues/127666
// See https://github.com/llvm/llvm-project/issues/95183 // See https://github.com/llvm/llvm-project/issues/95183
// hip-clang std::clamp __glibcxx_assert_fail host function when building on // hip-clang std::clamp __glibcxx_assert_fail host function when building on
...@@ -84,7 +80,6 @@ static inline __device__ int8_t int32_to_int8(int32_t x) { ...@@ -84,7 +80,6 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
static_cast<int32_t>(std::numeric_limits<int8_t>::max()); static_cast<int32_t>(std::numeric_limits<int8_t>::max());
// saturate // saturate
// See https://github.com/pytorch/pytorch/issues/127666 // See https://github.com/pytorch/pytorch/issues/127666
// See https://github.com/llvm/llvm-project/issues/95183 // See https://github.com/llvm/llvm-project/issues/95183
// hip-clang std::clamp __glibcxx_assert_fail host function when building on // hip-clang std::clamp __glibcxx_assert_fail host function when building on
...@@ -176,7 +171,6 @@ __global__ void dynamic_scaled_int8_quant_kernel( ...@@ -176,7 +171,6 @@ __global__ void dynamic_scaled_int8_quant_kernel(
float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax; float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax;
// 2. quantize
vectorize_with_alignment<16>( vectorize_with_alignment<16>(
row_in, row_out, hidden_size, tid, stride, row_in, row_out, hidden_size, tid, stride,
[=] __device__(int8_t& dst, const scalar_t& src) { [=] __device__(int8_t& dst, const scalar_t& src) {
...@@ -194,7 +188,6 @@ struct MinMax { ...@@ -194,7 +188,6 @@ struct MinMax {
__host__ __device__ explicit MinMax(float v) : min(v), max(v) {} __host__ __device__ explicit MinMax(float v) : min(v), max(v) {}
// add a value to the MinMax
__host__ __device__ MinMax& operator+=(float v) { __host__ __device__ MinMax& operator+=(float v) {
min = fminf(min, v); min = fminf(min, v);
max = fmaxf(max, v); max = fmaxf(max, v);
...@@ -228,7 +221,6 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( ...@@ -228,7 +221,6 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
const scalar_t* row_in = input + token_idx * hidden_size; const scalar_t* row_in = input + token_idx * hidden_size;
int8_t* row_out = output + token_idx * hidden_size; int8_t* row_out = output + token_idx * hidden_size;
// 1. calculate min & max
MinMax thread_mm; MinMax thread_mm;
vectorize_read_with_alignment<16>(row_in, hidden_size, tid, stride, vectorize_read_with_alignment<16>(row_in, hidden_size, tid, stride,
[&] __device__(const scalar_t& src) { [&] __device__(const scalar_t& src) {
...@@ -261,7 +253,6 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( ...@@ -261,7 +253,6 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
const float inv_s = 1.f / scale_sh; const float inv_s = 1.f / scale_sh;
const azp_t azp = azp_sh; const azp_t azp = azp_sh;
// 2. quantize
vectorize_with_alignment<16>( vectorize_with_alignment<16>(
row_in, row_out, hidden_size, tid, stride, row_in, row_out, hidden_size, tid, stride,
[=] __device__(int8_t& dst, const scalar_t& src) { [=] __device__(int8_t& dst, const scalar_t& src) {
...@@ -285,6 +276,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] ...@@ -285,6 +276,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
int const num_tokens = input.numel() / hidden_size; int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens); dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 256)); dim3 const block(std::min(hidden_size, 256));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
...@@ -316,6 +308,7 @@ void dynamic_scaled_int8_quant( ...@@ -316,6 +308,7 @@ void dynamic_scaled_int8_quant(
int const num_tokens = input.numel() / hidden_size; int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens); dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 256)); dim3 const block(std::min(hidden_size, 256));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] { input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
...@@ -332,14 +325,4 @@ void dynamic_scaled_int8_quant( ...@@ -332,14 +325,4 @@ void dynamic_scaled_int8_quant(
hidden_size); hidden_size);
} }
}); });
} }
\ No newline at end of file
#ifndef USE_ROCM
void per_token_group_quant_int8(const torch::Tensor& input,
torch::Tensor& output_q,
torch::Tensor& output_s, int64_t group_size,
double eps, double int8_min, double int8_max) {
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
int8_min, int8_max);
}
#endif
#pragma once #pragma once
#include <torch/all.h> #include <torch/all.h>
// TODO(wentao): refactor the folder to 8bit, then includes fp8 and int8 folders
// 8-bit per-token-group quantization helper used by both FP8 and INT8 // 8-bit per-token-group quantization helper used by both FP8 and INT8
void per_token_group_quant_8bit(const torch::Tensor& input, void per_token_group_quant_8bit(const torch::Tensor& input,
torch::Tensor& output_q, torch::Tensor& output_q,
......
...@@ -22,13 +22,14 @@ template <typename AllReduceKernel, typename T> ...@@ -22,13 +22,14 @@ template <typename AllReduceKernel, typename T>
__global__ __quickreduce_launch_bounds_two_shot__ static void __global__ __quickreduce_launch_bounds_two_shot__ static void
allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks, allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks,
int rank, uint8_t** dbuffer_list, int rank, uint8_t** dbuffer_list,
uint32_t data_offset, uint32_t flag_color) { uint32_t data_offset, uint32_t flag_color,
int64_t data_size_per_phase) {
int block = blockIdx.x; int block = blockIdx.x;
int grid = gridDim.x; int grid = gridDim.x;
while (block < num_blocks) { while (block < num_blocks) {
AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset, AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset,
flag_color); flag_color, data_size_per_phase);
block += grid; block += grid;
flag_color++; flag_color++;
} }
...@@ -41,21 +42,21 @@ allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks, ...@@ -41,21 +42,21 @@ allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks,
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \ hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
num_blocks, rank, dbuffer_list, data_offset, \ num_blocks, rank, dbuffer_list, data_offset, \
flag_color); \ flag_color, this->kMaxProblemSize); \
} else if (world_size == 4) { \ } else if (world_size == 4) { \
using LineCodec = __codec<T, 4>; \ using LineCodec = __codec<T, 4>; \
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \ using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \ hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
num_blocks, rank, dbuffer_list, data_offset, \ num_blocks, rank, dbuffer_list, data_offset, \
flag_color); \ flag_color, this->kMaxProblemSize); \
} else if (world_size == 8) { \ } else if (world_size == 8) { \
using LineCodec = __codec<T, 8>; \ using LineCodec = __codec<T, 8>; \
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \ using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \ hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
num_blocks, rank, dbuffer_list, data_offset, \ num_blocks, rank, dbuffer_list, data_offset, \
flag_color); \ flag_color, this->kMaxProblemSize); \
} }
enum QuickReduceQuantLevel { enum QuickReduceQuantLevel {
......
...@@ -553,13 +553,12 @@ struct AllReduceTwoshot { ...@@ -553,13 +553,12 @@ struct AllReduceTwoshot {
int const rank, // rank index int const rank, // rank index
uint8_t** __restrict__ buffer_list, // communication buffers uint8_t** __restrict__ buffer_list, // communication buffers
uint32_t const data_offset, // offset to start of the data buffer uint32_t const data_offset, // offset to start of the data buffer
uint32_t flag_color) { uint32_t flag_color, int64_t data_size_per_phase) {
// Topology // Topology
int thread = threadIdx.x + threadIdx.y * kWavefront; int thread = threadIdx.x + threadIdx.y * kWavefront;
uint8_t* rank_buffer = buffer_list[rank]; uint8_t* rank_buffer = buffer_list[rank];
Codec codec(thread, rank); Codec codec(thread, rank);
int block_id = blockIdx.x; int block_id = blockIdx.x;
int grid_size = gridDim.x;
// -------------------------------------------------------- // --------------------------------------------------------
// Read input into registers // Read input into registers
int32x4_t tA[kAtoms]; int32x4_t tA[kAtoms];
...@@ -588,12 +587,10 @@ struct AllReduceTwoshot { ...@@ -588,12 +587,10 @@ struct AllReduceTwoshot {
// rank responsible for this segment. // rank responsible for this segment.
uint32_t comm_data0_offset = uint32_t comm_data0_offset =
data_offset + block_id * Codec::kTransmittedTileSize; data_offset + block_id * Codec::kTransmittedTileSize;
uint32_t comm_data1_offset = uint32_t comm_data1_offset = data_size_per_phase + comm_data0_offset;
grid_size * Codec::kTransmittedTileSize + comm_data0_offset;
uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t)); uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t));
uint32_t comm_flags1_offset = uint32_t comm_flags1_offset = (data_offset / 2) + comm_flags0_offset;
grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset;
for (int r = 0; r < kWorldSize; r++) { for (int r = 0; r < kWorldSize; r++) {
int32x4_t* send_buffer = int32x4_t* send_buffer =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment