Commit 469e903b authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.2' into v0.8.2-dev

parents 389ebcf7 25f560a6
...@@ -7,8 +7,3 @@ inline constexpr uint32_t next_pow_2(uint32_t const num) { ...@@ -7,8 +7,3 @@ inline constexpr uint32_t next_pow_2(uint32_t const num) {
if (num <= 1) return num; if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
} }
template <typename T>
inline constexpr std::enable_if_t<std::is_integral_v<T>, T> ceil_div(T a, T b) {
return (a + b - 1) / b;
}
\ No newline at end of file
...@@ -24,8 +24,8 @@ struct KernelVecType<float> { ...@@ -24,8 +24,8 @@ struct KernelVecType<float> {
template <> template <>
struct KernelVecType<c10::Half> { struct KernelVecType<c10::Half> {
#ifdef __powerpc64__ #if defined(__powerpc64__) || defined(__s390x__)
// Power architecture-specific vector types // Power and s390x architecture-specific vector types
using q_load_vec_type = vec_op::FP32Vec8; using q_load_vec_type = vec_op::FP32Vec8;
using k_load_vec_type = vec_op::FP32Vec16; using k_load_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::FP32Vec16; using v_load_vec_type = vec_op::FP32Vec16;
......
...@@ -3,6 +3,12 @@ ...@@ -3,6 +3,12 @@
#include "cpu_types.hpp" #include "cpu_types.hpp"
#if defined(__x86_64__)
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2
#else
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES
#endif
namespace { namespace {
template <typename scalar_t> template <typename scalar_t>
void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches, void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
...@@ -95,13 +101,12 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches, ...@@ -95,13 +101,12 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
} }
const int element_num_per_block = key_caches[0][0].numel(); const int element_num_per_block = key_caches[0][0].numel();
VLLM_DISPATCH_FLOATING_TYPES( DISPATCH_MACRO(key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] { CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl) copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping, element_num_per_block, num_layers);
element_num_per_block, num_layers); CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl) });
});
} }
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
...@@ -118,16 +123,15 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, ...@@ -118,16 +123,15 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
int key_stride = key.stride(0); int key_stride = key.stride(0);
int value_stride = value.stride(0); int value_stride = value.stride(0);
VLLM_DISPATCH_FLOATING_TYPES( DISPATCH_MACRO(key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
key.scalar_type(), "reshape_and_cache_cpu_impl", [&] { CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl) reshape_and_cache_cpu_impl<scalar_t>(
reshape_and_cache_cpu_impl<scalar_t>( key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(), key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(), slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride, value_stride,
slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride, num_heads, head_size, block_size, x);
value_stride, num_heads, head_size, block_size, x); CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl) });
});
} }
void swap_blocks(torch::Tensor& src, torch::Tensor& dst, void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
......
...@@ -7,6 +7,9 @@ ...@@ -7,6 +7,9 @@
#elif defined(__POWER9_VECTOR__) #elif defined(__POWER9_VECTOR__)
// ppc implementation // ppc implementation
#include "cpu_types_vsx.hpp" #include "cpu_types_vsx.hpp"
#elif defined(__s390x__)
// s390 implementation
#include "cpu_types_vxe.hpp"
#elif defined(__aarch64__) #elif defined(__aarch64__)
// arm implementation // arm implementation
#include "cpu_types_arm.hpp" #include "cpu_types_arm.hpp"
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include <torch/all.h> #include <torch/all.h>
#include <cmath> #include <cmath>
#if defined(__APPLE__)
#include "omp.h"
#endif
namespace vec_op { namespace vec_op {
#ifdef ARM_BF16_SUPPORT #ifdef ARM_BF16_SUPPORT
......
#ifndef CPU_TYPES_VXE_HPP
#define CPU_TYPES_VXE_HPP
#include <vecintrin.h>
#include <cmath>
#include <torch/all.h>
namespace vec_op {
#define vec_neg(a) (-(a))
#define vec_add(a, b) ((a) + (b))
#define vec_sub(a, b) ((a) - (b))
#define vec_mul(a, b) ((a) * (b))
#define vec_div(a, b) ((a) / (b))
#define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right Algebaic
#define vec_sl(a, b) ((a) << (b)) // Vector Shift Left
// FIXME: FP16 is not fully supported in Torch-CPU
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#else
#define CPU_KERNEL_GUARD_IN(NAME) \
std::cout << #NAME << " invoked." << std::endl;
#define CPU_KERNEL_GUARD_OUT(NAME) \
std::cout << #NAME << " exit." << std::endl;
#endif
#define FORCE_INLINE __attribute__((always_inline)) inline
namespace {
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
(f(std::integral_constant<T, indexes>{}), ...);
}
}; // namespace
template <typename T, T count, typename F,
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F&& f) {
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}
template <typename T>
struct Vec {
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
};
typedef struct ss16x8x2_t {
__vector signed short val[2];
} ss16x8x2_t;
typedef struct ss16x8x4_t {
__vector signed short val[4];
} ss16x8x4_t;
typedef struct f32x4x2_t {
__vector float val[2];
} f32x4x2_t;
typedef struct f32x4x4_t {
__vector float val[4];
} f32x4x4_t;
struct FP32Vec8;
struct FP32Vec16;
struct BF16Vec8 : public Vec<BF16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
__vector signed short reg;
explicit BF16Vec8(const void* ptr) : reg(*(__vector signed short*)ptr) {}
explicit BF16Vec8(const FP32Vec8&);
void save(void* ptr) const {
*reinterpret_cast<__vector signed short*>(ptr) = reg;
}
};
struct BF16Vec16 : public Vec<BF16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
ss16x8x2_t reg;
explicit BF16Vec16(const void* ptr) {
// Load 256 bits in two parts
reg.val[0] = (__vector signed short)vec_xl(0, (signed short*)ptr);
reg.val[1] = (__vector signed short)vec_xl(16, (signed short*)ptr);
}
explicit BF16Vec16(const FP32Vec16&);
void save(void* ptr) const {
// Save 256 bits in two parts
vec_xst(reg.val[0], 0, (signed short*)ptr);
vec_xst(reg.val[1], 16, (signed short*)ptr);
}
};
const static __vector signed short zero = vec_splats((signed short)0);
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;
ss16x8x4_t reg;
explicit BF16Vec32(const void* ptr)
: reg(*reinterpret_cast<const ss16x8x4_t*>(ptr)) {}
explicit BF16Vec32(ss16x8x4_t data) : reg(data) {}
explicit BF16Vec32(const BF16Vec8& vec8_data)
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {}
void save(void* ptr) const { *reinterpret_cast<ss16x8x4_t*>(ptr) = reg; }
};
struct FP32Vec4 : public Vec<FP32Vec4> {
constexpr static int VEC_ELEM_NUM = 4;
union AliasReg {
__vector float reg;
float values[VEC_ELEM_NUM];
};
__vector float reg;
explicit FP32Vec4(float v) : reg(vec_splats(v)) {}
explicit FP32Vec4() : reg(vec_splats(0.0f)) {}
explicit FP32Vec4(const float* ptr) : reg(vec_xl(0, ptr)) {}
explicit FP32Vec4(__vector float data) : reg(data) {}
explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}
};
struct FP32Vec8 : public Vec<FP32Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
union AliasReg {
f32x4x2_t reg;
float values[VEC_ELEM_NUM];
};
f32x4x2_t reg;
explicit FP32Vec8(float v) {
reg.val[0] = vec_splats(v);
reg.val[1] = vec_splats(v);
}
explicit FP32Vec8() {
reg.val[0] = vec_splats(0.0f);
reg.val[1] = vec_splats(0.0f);
}
explicit FP32Vec8(const float* ptr) {
reg.val[0] = vec_xl(0, ptr);
reg.val[1] = vec_xl(16, ptr);
}
explicit FP32Vec8(f32x4x2_t data) : reg(data) {}
explicit FP32Vec8(const FP32Vec8& data) {
reg.val[0] = data.reg.val[0];
reg.val[1] = data.reg.val[1];
}
explicit FP32Vec8(const BF16Vec8& v) {
reg.val[0] = (__vector float)vec_mergeh(zero, v.reg);
reg.val[1] = (__vector float)vec_mergel(zero, v.reg);
}
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
float result = 0;
unroll_loop<int, VEC_ELEM_NUM>(
[&result, &ar](int i) { result += ar.values[i]; });
return result;
}
FP32Vec8 exp() const {
// TODO: Vectorize this
AliasReg ar;
ar.reg = reg;
f32x4x4_t ret;
ret.val[0][0] = std::exp(ar.values[0]);
ret.val[0][1] = std::exp(ar.values[1]);
ret.val[0][2] = std::exp(ar.values[2]);
ret.val[0][3] = std::exp(ar.values[3]);
ret.val[1][0] = std::exp(ar.values[4]);
ret.val[1][1] = std::exp(ar.values[5]);
ret.val[1][2] = std::exp(ar.values[6]);
ret.val[1][3] = std::exp(ar.values[7]);
return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
}
FP32Vec8 tanh() const {
// TODO: Vectorize this
AliasReg ar;
ar.reg = reg;
f32x4x4_t ret;
ret.val[0][0] = std::tanh(ar.values[0]);
ret.val[0][1] = std::tanh(ar.values[1]);
ret.val[0][2] = std::tanh(ar.values[2]);
ret.val[0][3] = std::tanh(ar.values[3]);
ret.val[1][0] = std::tanh(ar.values[4]);
ret.val[1][1] = std::tanh(ar.values[5]);
ret.val[1][2] = std::tanh(ar.values[6]);
ret.val[1][3] = std::tanh(ar.values[7]);
return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
}
FP32Vec8 er() const {
// TODO: Vectorize this
AliasReg ar;
ar.reg = reg;
f32x4x4_t ret;
ret.val[0][0] = std::erf(ar.values[0]);
ret.val[0][1] = std::erf(ar.values[1]);
ret.val[0][2] = std::erf(ar.values[2]);
ret.val[0][3] = std::erf(ar.values[3]);
ret.val[1][0] = std::erf(ar.values[4]);
ret.val[1][1] = std::erf(ar.values[5]);
ret.val[1][2] = std::erf(ar.values[6]);
ret.val[1][3] = std::erf(ar.values[7]);
return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
}
FP32Vec8 operator*(const FP32Vec8& b) const {
return FP32Vec8(
{vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])});
}
FP32Vec8 operator+(const FP32Vec8& b) const {
return FP32Vec8(
{vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])});
}
FP32Vec8 operator-(const FP32Vec8& b) const {
return FP32Vec8(
{vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])});
}
FP32Vec8 operator/(const FP32Vec8& b) const {
return FP32Vec8(
{vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])});
}
void save(float* ptr) const {
vec_xst(reg.val[0], 0, ptr);
vec_xst(reg.val[1], 16, ptr);
}
};
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
f32x4x4_t reg;
float values[VEC_ELEM_NUM];
};
f32x4x4_t reg;
explicit FP32Vec16(float v) {
reg.val[0] = vec_splats(v);
reg.val[1] = vec_splats(v);
reg.val[2] = vec_splats(v);
reg.val[3] = vec_splats(v);
}
explicit FP32Vec16() {
reg.val[0] = vec_splats(0.0f);
reg.val[1] = vec_splats(0.0f);
reg.val[2] = vec_splats(0.0f);
reg.val[3] = vec_splats(0.0f);
}
explicit FP32Vec16(const float* ptr) {
reg.val[0] = vec_xl(0, ptr);
reg.val[1] = vec_xl(16, ptr);
reg.val[2] = vec_xl(32, ptr);
reg.val[3] = vec_xl(48, ptr);
}
explicit FP32Vec16(f32x4x4_t data) : reg(data) {}
explicit FP32Vec16(const FP32Vec16& data) {
reg.val[0] = data.reg.val[0];
reg.val[1] = data.reg.val[1];
reg.val[2] = data.reg.val[2];
reg.val[3] = data.reg.val[3];
}
explicit FP32Vec16(const FP32Vec4& data) {
reg.val[0] = data.reg;
reg.val[1] = data.reg;
reg.val[2] = data.reg;
reg.val[3] = data.reg;
}
explicit FP32Vec16(const FP32Vec8& data) {
reg.val[0] = data.reg.val[0];
reg.val[1] = data.reg.val[1];
reg.val[2] = data.reg.val[0];
reg.val[3] = data.reg.val[1];
}
explicit FP32Vec16(const BF16Vec16& v) {
reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]);
reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]);
reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]);
reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]);
}
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
FP32Vec16 operator*(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]),
vec_mul(reg.val[1], b.reg.val[1]),
vec_mul(reg.val[2], b.reg.val[2]),
vec_mul(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 operator+(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_add(reg.val[0], b.reg.val[0]),
vec_add(reg.val[1], b.reg.val[1]),
vec_add(reg.val[2], b.reg.val[2]),
vec_add(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 operator-(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_sub(reg.val[0], b.reg.val[0]),
vec_sub(reg.val[1], b.reg.val[1]),
vec_sub(reg.val[2], b.reg.val[2]),
vec_sub(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 operator/(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_div(reg.val[0], b.reg.val[0]),
vec_div(reg.val[1], b.reg.val[1]),
vec_div(reg.val[2], b.reg.val[2]),
vec_div(reg.val[3], b.reg.val[3])}));
}
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
float result = 0;
unroll_loop<int, VEC_ELEM_NUM>(
[&result, &ar](int i) { result += ar.values[i]; });
return result;
}
template <int group_size>
float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0);
AliasReg ar;
ar.reg = reg;
float result = 0;
const int start = idx * group_size;
unroll_loop<int, group_size>(
[&result, &start, ar](int i) { result += ar.values[start + i]; });
return result;
}
void save(float* ptr) const {
vec_xst(reg.val[0], 0, ptr);
vec_xst(reg.val[1], 16, ptr);
vec_xst(reg.val[2], 32, ptr);
vec_xst(reg.val[3], 48, ptr);
}
};
template <typename T>
struct VecType {
using vec_type = void;
};
template <typename T>
using vec_t = typename VecType<T>::vec_type;
template <>
struct VecType<float> {
using vec_type = FP32Vec8;
};
template <>
struct VecType<c10::BFloat16> {
using vec_type = BF16Vec8;
};
template <typename T>
void storeFP32(float v, T* ptr) {
*ptr = v;
}
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
acc = acc + a * b;
}
namespace c10 {
struct BFloat16 {
uint16_t value; // Assume BFloat16 is defined as a struct containing a 16-bit
// value.
};
} // namespace c10
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
c10::BFloat16 __attribute__((__may_alias__))* v_ptr =
reinterpret_cast<c10::BFloat16*>(&v);
*ptr = *(v_ptr + 1);
}
#ifndef __VEC_CLASS_FP_NAN
#define __VEC_CLASS_FP_NAN (1 << 6)
#endif
const static __vector unsigned char omask = {2, 3, 6, 7, 10, 11, 14, 15,
18, 19, 22, 23, 26, 27, 30, 31};
const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff,
0x00007fff};
const static __vector unsigned int nan = {0x7fc00000, 0x7fc00000, 0x7fc00000,
0x7fc00000};
const static __vector unsigned int sh16 = {16, 16, 16, 16};
const static __vector unsigned int one = {1, 1, 1, 1};
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
int cc;
__vector __bool int sel0 =
vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc);
__vector __bool int sel1 =
vec_fp_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN, &cc);
inp0 = vec_sel(inp0, nan, sel0) >> sh16;
inp1 = vec_sel(inp1, nan, sel1) >> sh16;
reg = (__vector signed short)vec_perm(inp0, inp1, omask);
}
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
__vector unsigned int inp2 = (__vector unsigned int)(v.reg.val[2]);
__vector unsigned int inp3 = (__vector unsigned int)(v.reg.val[3]);
int cc;
__vector __bool int sel0 =
vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc);
__vector __bool int sel1 =
vec_fp_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN, &cc);
__vector __bool int sel2 =
vec_fp_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN, &cc);
__vector __bool int sel3 =
vec_fp_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN, &cc);
inp0 = vec_sel(inp0, nan, sel0) >> sh16;
inp1 = vec_sel(inp1, nan, sel1) >> sh16;
inp2 = vec_sel(inp2, nan, sel2) >> sh16;
inp3 = vec_sel(inp3, nan, sel3) >> sh16;
reg.val[0] = (__vector signed short)vec_perm(inp0, inp1, omask);
reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask);
}
inline void prefetch(const void* addr) { void __dcbt(const void* addr); }
}; // namespace vec_op
#endif
\ No newline at end of file
...@@ -16,9 +16,18 @@ namespace vec_op { ...@@ -16,9 +16,18 @@ namespace vec_op {
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(__VA_ARGS__))
#ifndef CPU_OP_GUARD #ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME) #define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME) #define CPU_KERNEL_GUARD_OUT(NAME)
......
...@@ -170,7 +170,7 @@ void rotary_embedding_gptj_impl( ...@@ -170,7 +170,7 @@ void rotary_embedding_gptj_impl(
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int64_t head_size, torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox) { torch::Tensor& cos_sin_cache, bool is_neox) {
int num_tokens = query.numel() / query.size(-1); int num_tokens = positions.numel();
int rot_dim = cos_sin_cache.size(1); int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size; int num_heads = query.size(-1) / head_size;
int num_kv_heads = key.size(-1) / head_size; int num_kv_heads = key.size(-1) / head_size;
......
...@@ -25,7 +25,7 @@ struct KernelVecType<c10::BFloat16> { ...@@ -25,7 +25,7 @@ struct KernelVecType<c10::BFloat16> {
template <> template <>
struct KernelVecType<c10::Half> { struct KernelVecType<c10::Half> {
#ifdef __powerpc64__ #if defined(__powerpc64__) || defined(__s390x__)
// Power architecture-specific vector type // Power architecture-specific vector type
using load_vec_type = vec_op::FP32Vec16; using load_vec_type = vec_op::FP32Vec16;
#else #else
......
...@@ -2,10 +2,14 @@ ...@@ -2,10 +2,14 @@
#include <stdio.h> #include <stdio.h>
#if defined(__CUDACC__) || defined(_NVHPC_CUDA) #if defined(__HIPCC__)
#define HOST_DEVICE_INLINE __forceinline__ __host__ __device__ #define HOST_DEVICE_INLINE __host__ __device__
#define DEVICE_INLINE __forceinline__ __device__ #define DEVICE_INLINE __device__
#define HOST_INLINE __forceinline__ __host__ #define HOST_INLINE __host__
#elif defined(__CUDACC__) || defined(_NVHPC_CUDA)
#define HOST_DEVICE_INLINE __host__ __device__ __forceinline__
#define DEVICE_INLINE __device__ __forceinline__
#define HOST_INLINE __host__ __forceinline__
#else #else
#define HOST_DEVICE_INLINE inline #define HOST_DEVICE_INLINE inline
#define DEVICE_INLINE inline #define DEVICE_INLINE inline
...@@ -25,3 +29,13 @@ ...@@ -25,3 +29,13 @@
int64_t get_device_attribute(int64_t attribute, int64_t device_id); int64_t get_device_attribute(int64_t attribute, int64_t device_id);
int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id); int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);
namespace cuda_utils {
template <typename T>
HOST_DEVICE_INLINE constexpr std::enable_if_t<std::is_integral_v<T>, T>
ceil_div(T a, T b) {
return (a + b - 1) / b;
}
}; // namespace cuda_utils
\ No newline at end of file
...@@ -142,3 +142,44 @@ void register_graph_buffers(fptr_t _fa, ...@@ -142,3 +142,44 @@ void register_graph_buffers(fptr_t _fa,
bytes.reserve(handles.size()); bytes.reserve(handles.size());
fa->register_graph_buffers(bytes, offsets); fa->register_graph_buffers(bytes, offsets);
} }
std::tuple<fptr_t, torch::Tensor> allocate_shared_buffer_and_handle(
int64_t size) {
auto device_index = c10::cuda::current_device();
at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
void* buffer;
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
auto stream = c10::cuda::getCurrentCUDAStream().stream();
AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
#if defined(USE_ROCM)
// data buffers need to be "uncached" for signal on MI200
AT_CUDA_CHECK(
hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
#else
AT_CUDA_CHECK(cudaMalloc((void**)&buffer, size));
#endif
AT_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream));
AT_CUDA_CHECK(cudaStreamSynchronize(stream));
AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto handle =
torch::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, options);
AT_CUDA_CHECK(
cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data_ptr(), buffer));
return std::make_tuple(reinterpret_cast<fptr_t>(buffer), handle);
}
fptr_t open_mem_handle(torch::Tensor& mem_handle) {
void* ipc_ptr;
AT_CUDA_CHECK(cudaIpcOpenMemHandle(
(void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data_ptr()),
cudaIpcMemLazyEnablePeerAccess));
return reinterpret_cast<fptr_t>(ipc_ptr);
}
void free_shared_buffer(fptr_t buffer) {
AT_CUDA_CHECK(cudaFree(reinterpret_cast<void*>(buffer)));
}
...@@ -5,6 +5,10 @@ ...@@ -5,6 +5,10 @@
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#if defined(USE_ROCM)
typedef __hip_bfloat16 nv_bfloat16;
#endif
#include <iostream> #include <iostream>
#include <array> #include <array>
#include <limits> #include <limits>
...@@ -12,6 +16,7 @@ ...@@ -12,6 +16,7 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
namespace vllm {
#define CUDACHECK(cmd) \ #define CUDACHECK(cmd) \
do { \ do { \
cudaError_t e = cmd; \ cudaError_t e = cmd; \
...@@ -22,24 +27,37 @@ ...@@ -22,24 +27,37 @@
} \ } \
} while (0) } while (0)
namespace vllm { // Maximal number of blocks in allreduce kernel.
constexpr int kMaxBlocks = 36; constexpr int kMaxBlocks = 36;
// Default number of blocks in allreduce kernel.
#ifndef USE_ROCM
const int defaultBlockLimit = 36;
CUpointer_attribute rangeStartAddrAttr =
CUDA_POINTER_ATTRIBUTE_RANGE_START_ADDR;
#else
const int defaultBlockLimit = 16;
hipPointer_attribute rangeStartAddrAttr =
HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR;
#endif
// Counter may overflow, but it's fine since unsigned int overflow is // Counter may overflow, but it's fine since unsigned int overflow is
// well-defined behavior. // well-defined behavior.
using FlagType = uint32_t; using FlagType = uint32_t;
// Two sets of peer counters are needed for two syncs. The reason is that
// it's possible for peer GPU block to arrive at the second sync point while
// the current GPU block haven't passed the first sync point. Thus, peer GPU
// may write counter+1 while current GPU is busy waiting for counter. We use
// alternating counter array to avoid this possibility.
struct Signal { struct Signal {
alignas(128) FlagType self_counter[kMaxBlocks][8]; alignas(128) FlagType start[kMaxBlocks][8];
// Two sets of peer counters are needed for two syncs. The reason is that alignas(128) FlagType end[kMaxBlocks][8];
// it's possible for peer GPU block to arrive at the second sync point while alignas(128) FlagType _flag[kMaxBlocks]; // incremental flags for each rank
// the current GPU block haven't passed the first sync point. Thus, peer GPU
// may write counter+1 while current GPU is busy waiting for counter. We use
// alternating counter array to avoid this possibility.
alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
}; };
struct __align__(16) RankData { struct __align__(16) RankData {
const void* __restrict__ ptrs[8]; const void* ptrs[8];
}; };
struct __align__(16) RankSignals { struct __align__(16) RankSignals {
...@@ -134,27 +152,29 @@ DINLINE O downcast(array_t<float, O::size> val) { ...@@ -134,27 +152,29 @@ DINLINE O downcast(array_t<float, O::size> val) {
} }
} }
#if !defined(USE_ROCM)
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) { static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag), asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag),
"l"(flag_addr)); "l"(flag_addr));
#else #else
asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag), asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag),
"l"(flag_addr)); "l"(flag_addr));
#endif #endif
} }
static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) { static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
FlagType flag; FlagType flag;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("ld.acquire.sys.global.u32 %0, [%1];" asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
: "=r"(flag) : "=r"(flag)
: "l"(flag_addr)); : "l"(flag_addr));
#else #else
asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;" asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;"
: "=r"(flag) : "=r"(flag)
: "l"(flag_addr)); : "l"(flag_addr));
#endif #endif
return flag; return flag;
} }
...@@ -170,37 +190,108 @@ static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) { ...@@ -170,37 +190,108 @@ static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
return flag; return flag;
} }
// is_start: whether this is the very first synchronization barrier. // This function is meant to be used as the first synchronization in the all
// need_fence: whether a memory fence is needed. If true, a release-acquire // reduce kernel. Thus, it doesn't need to make any visibility guarantees for
// semantic is used to enforce memory access order before and after this // prior memory accesses. Note: volatile writes will not be reordered against
// barrier. // other volatile writes.
template <int ngpus, bool is_start, bool need_fence = false> template <int ngpus>
DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg, DINLINE void start_sync(const RankSignals& sg, Signal* self_sg, int rank) {
int rank) { uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
if constexpr (!is_start) __syncthreads();
static_assert(
!(is_start && need_fence)); // Start barrier shouldn't need fence.
if (threadIdx.x < ngpus) { if (threadIdx.x < ngpus) {
// Increment the counter. Technically we only need one counter, but we use auto peer_counter_ptr = &sg.signals[threadIdx.x]->start[blockIdx.x][rank];
// multiple per block to eliminate the need to share the counter via smem. auto self_counter_ptr = &self_sg->start[blockIdx.x][threadIdx.x];
auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1; // Write the expected counter value to peer and wait for correct value
// from peer.
st_flag_volatile(peer_counter_ptr, flag);
while (ld_flag_volatile(self_counter_ptr) != flag);
}
__syncthreads();
// use one thread to update flag
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}
// This function is meant to be used as the second or the final
// synchronization barrier in the all reduce kernel. If it's the final
// synchronization barrier, we don't need to make any visibility guarantees
// for prior memory accesses.
template <int ngpus, bool final_sync = false>
DINLINE void end_sync(const RankSignals& sg, Signal* self_sg, int rank) {
__syncthreads();
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
if (threadIdx.x < ngpus) {
auto peer_counter_ptr = &sg.signals[threadIdx.x]->end[blockIdx.x][rank];
auto self_counter_ptr = &self_sg->end[blockIdx.x][threadIdx.x];
// Write the expected counter value to peer and wait for correct value from // Write the expected counter value to peer and wait for correct value from
// peer. // peer.
auto peer_counter_ptr = if constexpr (!final_sync) {
&sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank]; st_flag_release(peer_counter_ptr, flag);
auto self_counter_ptr = while (ld_flag_acquire(self_counter_ptr) != flag);
&self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
if constexpr (need_fence) {
st_flag_release(peer_counter_ptr, val);
while (ld_flag_acquire(self_counter_ptr) != val);
} else { } else {
st_flag_volatile(peer_counter_ptr, val); st_flag_volatile(peer_counter_ptr, flag);
while (ld_flag_volatile(self_counter_ptr) != val); while (ld_flag_volatile(self_counter_ptr) != flag);
} }
} }
if constexpr (is_start || need_fence) __syncthreads(); if constexpr (!final_sync) __syncthreads();
// use one thread to update flag
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}
#else
template <int ngpus>
DINLINE void start_sync(const RankSignals& sg, Signal* self_sg, int rank) {
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
if (threadIdx.x < ngpus) {
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
// __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
// flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
// // wait until we got true from all ranks
// while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
// __ATOMIC_RELAXED,
// __MEMORY_SCOPE_DEVICE) < flag);
__atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag,
__ATOMIC_RELAXED);
// wait until we got true from all ranks
while (__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
__ATOMIC_RELAXED) < flag);
}
__syncthreads();
// use one thread to update flag
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}
template <int ngpus, bool final_sync = false>
DINLINE void end_sync(const RankSignals& sg, Signal* self_sg, int rank) {
__syncthreads();
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
if (threadIdx.x < ngpus) {
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
// __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
// flag,
// final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
// __MEMORY_SCOPE_SYSTEM);
// // wait until we got true from all ranks
// while (
// __scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
// final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
// __MEMORY_SCOPE_DEVICE) < flag);
__atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag,
final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE);
// wait until we got true from all ranks
while (__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE) <
flag);
}
if constexpr (!final_sync) __syncthreads();
// use one thread to update flag
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
} }
#endif
template <typename P, int ngpus, typename A> template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) { DINLINE P packed_reduce(const P* ptrs[], int idx) {
A tmp = upcast(ptrs[0][idx]); A tmp = upcast(ptrs[0][idx]);
...@@ -220,13 +311,13 @@ __global__ void __launch_bounds__(512, 1) ...@@ -220,13 +311,13 @@ __global__ void __launch_bounds__(512, 1)
// note: we don't reorder the address so the accumulation order is the same // note: we don't reorder the address so the accumulation order is the same
// for all ranks, ensuring bitwise identical results // for all ranks, ensuring bitwise identical results
auto dp = *_dp; auto dp = *_dp;
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank); start_sync<ngpus>(sg, self_sg, rank);
// do the actual reduction // do the actual reduction
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx); ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
} }
multi_gpu_barrier<ngpus, false>(sg, self_sg, rank); end_sync<ngpus, true>(sg, self_sg, rank);
} }
template <typename P> template <typename P>
...@@ -255,12 +346,13 @@ __global__ void __launch_bounds__(512, 1) ...@@ -255,12 +346,13 @@ __global__ void __launch_bounds__(512, 1)
tmps[i] = get_tmp_buf<P>(sg.signals[target]); tmps[i] = get_tmp_buf<P>(sg.signals[target]);
} }
auto tmp_out = tmps[0]; auto tmp_out = tmps[0];
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank); start_sync<ngpus>(sg, self_sg, rank);
// stage 1: reduce scatter // stage 1: reduce scatter
for (int idx = start + tid; idx < end; idx += stride) { for (int idx = start + tid; idx < end; idx += stride) {
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx); tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
} }
multi_gpu_barrier<ngpus, false, true>(sg, self_sg, rank); end_sync<ngpus>(sg, self_sg, rank);
// stage 2: allgather. Note: it's important to match the tid between // stage 2: allgather. Note: it's important to match the tid between
// the two stages, because visibility across devices is only guaranteed // the two stages, because visibility across devices is only guaranteed
...@@ -290,7 +382,7 @@ class CustomAllreduce { ...@@ -290,7 +382,7 @@ class CustomAllreduce {
bool full_nvlink_; bool full_nvlink_;
RankSignals sg_; RankSignals sg_;
// Stores an map from a pointer to its peer pointters from all ranks. // Stores an map from a pointer to its peer pointers from all ranks.
std::unordered_map<void*, RankData*> buffers_; std::unordered_map<void*, RankData*> buffers_;
Signal* self_sg_; Signal* self_sg_;
...@@ -361,8 +453,7 @@ class CustomAllreduce { ...@@ -361,8 +453,7 @@ class CustomAllreduce {
void* base_ptr; void* base_ptr;
// note: must share the base address of each allocation, or we get wrong // note: must share the base address of each allocation, or we get wrong
// address // address
if (cuPointerGetAttribute(&base_ptr, if (cuPointerGetAttribute(&base_ptr, rangeStartAddrAttr,
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
(CUdeviceptr)ptr) != CUDA_SUCCESS) (CUdeviceptr)ptr) != CUDA_SUCCESS)
throw std::runtime_error("failed to get pointer attr"); throw std::runtime_error("failed to get pointer attr");
CUDACHECK(cudaIpcGetMemHandle( CUDACHECK(cudaIpcGetMemHandle(
...@@ -439,7 +530,7 @@ class CustomAllreduce { ...@@ -439,7 +530,7 @@ class CustomAllreduce {
*/ */
template <typename T> template <typename T>
void allreduce(cudaStream_t stream, T* input, T* output, int size, void allreduce(cudaStream_t stream, T* input, T* output, int size,
int threads = 512, int block_limit = 36) { int threads = 512, int block_limit = defaultBlockLimit) {
auto d = packed_t<T>::P::size; auto d = packed_t<T>::P::size;
if (size % d != 0) if (size % d != 0)
throw std::runtime_error( throw std::runtime_error(
...@@ -473,8 +564,6 @@ class CustomAllreduce { ...@@ -473,8 +564,6 @@ class CustomAllreduce {
#define KL(ngpus, name) \ #define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \ name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
rank_, size); rank_, size);
// TODO(hanzhi713): Threshold is different for A100 and H100.
// Add per device threshold.
#define REDUCE_CASE(ngpus) \ #define REDUCE_CASE(ngpus) \
case ngpus: { \ case ngpus: { \
if (world_size_ == 2) { \ if (world_size_ == 2) { \
...@@ -497,7 +586,8 @@ class CustomAllreduce { ...@@ -497,7 +586,8 @@ class CustomAllreduce {
REDUCE_CASE(8) REDUCE_CASE(8)
default: default:
throw std::runtime_error( throw std::runtime_error(
"custom allreduce only supports num gpus in (2,4,6,8). Actual num " "custom allreduce only supports num gpus in (2,4,6,8). Actual "
"num "
"gpus = " + "gpus = " +
std::to_string(world_size_)); std::to_string(world_size_));
} }
......
...@@ -20,9 +20,16 @@ ...@@ -20,9 +20,16 @@
#include <vector> #include <vector>
#include "cuda_profiler_api.h" #include "cuda_profiler_api.h"
#include "custom_all_reduce.cuh"
#include "mpi.h" #include "mpi.h"
#include "nccl.h" #ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 nv_bfloat16;
#include "rccl/rccl.h"
#include "custom_all_reduce_hip.cuh"
#else
#include "nccl.h"
#include "custom_all_reduce.cuh"
#endif
#define MPICHECK(cmd) \ #define MPICHECK(cmd) \
do { \ do { \
...@@ -44,13 +51,16 @@ ...@@ -44,13 +51,16 @@
} while (0) } while (0)
__global__ void dummy_kernel() { __global__ void dummy_kernel() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 #ifdef USE_ROCM
for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms
#else
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
long long int start = clock64(); uint64_t start = wall_clock64();
while (clock64() - start < 150000000); // approximately 98.4ms on P40 uint64_t cycles_elapsed;
do {
cycles_elapsed = wall_clock64() - start;
} while (cycles_elapsed < 100);
} }
#else
for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms
#endif #endif
} }
...@@ -121,8 +131,14 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, ...@@ -121,8 +131,14 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
* registration, they are allocated and registered together in the test for * registration, they are allocated and registered together in the test for
* convenience. * convenience.
*/ */
#ifdef USE_ROCM
CUDACHECK(hipExtMallocWithFlags(
(void**)&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal),
hipDeviceMallocUncached));
#else
CUDACHECK( CUDACHECK(
cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal))); cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal)));
#endif
CUDACHECK( CUDACHECK(
cudaMemset(buffer, 0, 2 * data_size * sizeof(T) + sizeof(vllm::Signal))); cudaMemset(buffer, 0, 2 * data_size * sizeof(T) + sizeof(vllm::Signal)));
CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T))); CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T)));
...@@ -135,26 +151,24 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, ...@@ -135,26 +151,24 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
void* rank_data; void* rank_data;
size_t rank_data_sz = 16 * 1024 * 1024; size_t rank_data_sz = 16 * 1024 * 1024;
CUDACHECK(cudaMalloc(&rank_data, rank_data_sz)); CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
vllm::Signal* ipc_ptrs[8]; std::vector<int64_t> offsets(nRanks, 0);
for (int i = 0; i < nRanks; i++) { vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
if (i == myRank) offsets, myRank);
ipc_ptrs[i] = buffer;
else
CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptrs[i], data_handles[i],
cudaIpcMemLazyEnablePeerAccess));
}
vllm::CustomAllreduce fa(ipc_ptrs, rank_data, rank_data_sz, myRank, nRanks);
auto* self_data = auto* self_data =
reinterpret_cast<T*>(reinterpret_cast<char*>(buffer) + reinterpret_cast<T*>(reinterpret_cast<char*>(buffer) +
sizeof(vllm::Signal) + data_size * sizeof(T)); sizeof(vllm::Signal) + data_size * sizeof(T));
// hack buffer registration // hack buffer registration
{ {
void* data[8]; std::vector<std::string> handles;
handles.reserve(nRanks);
for (int i = 0; i < nRanks; i++) { for (int i = 0; i < nRanks; i++) {
data[i] = char* begin = (char*)&data_handles[i];
((char*)ipc_ptrs[i]) + sizeof(vllm::Signal) + data_size * sizeof(T); char* end = (char*)&data_handles[i + 1];
handles.emplace_back(begin, end);
} }
fa.register_buffer(data); std::vector<int64_t> offsets(nRanks,
sizeof(vllm::Signal) + data_size * sizeof(T));
fa.register_buffer(handles, offsets, self_data);
} }
double* ground_truth; double* ground_truth;
...@@ -266,14 +280,14 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, ...@@ -266,14 +280,14 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
if (diff >= 4e-2) { if (diff >= 4e-2) {
printf( printf(
"Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n", "Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n",
myRank, j, nccl_result[j], my_result[j], ground_truth[j]); myRank, j, nccl_result[j], my_result[j], ground_truth[j]);
break; break;
} }
} }
} }
if (myRank == 0) if (myRank == 0)
printf("Test passed: nGPUs:%d, sz (kb): %d, %d, %d\n", nRanks, printf("Test passed: nGPUs:%d, sz (kb): %d, %d, %d\n", nRanks,
data_size * sizeof(T) / 1024, threads, block_limit); data_size * sizeof(T) / 1024, threads, block_limit);
// long double nccl_diffs = 0.0; // long double nccl_diffs = 0.0;
// long double my_diffs = 0.0; // long double my_diffs = 0.0;
// for (int j = 0; j < data_size; j++) { // for (int j = 0; j < data_size; j++) {
...@@ -306,24 +320,27 @@ int main(int argc, char** argv) { ...@@ -306,24 +320,27 @@ int main(int argc, char** argv) {
ncclComm_t comm; ncclComm_t comm;
if (myRank == 0) ncclGetUniqueId(&id); if (myRank == 0) ncclGetUniqueId(&id);
MPICHECK(MPI_Bcast(static_cast<void*>(&id), sizeof(id), MPI_BYTE, 0, MPICHECK(MPI_Bcast(static_cast<void*>(&id), sizeof(id), MPI_BYTE, 0,
MPI_COMM_WORLD)); MPI_COMM_WORLD));
NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank)); NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));
bool performance_test = true; bool performance_test = true;
cudaProfilerStart(); cudaProfilerStart();
// Uncomment to scan through different block size configs. // for (int threads : {256, 512}) {
// for (int threads : {256, 512, 1024}) {
// for (int block_limit = 16; block_limit < 112; block_limit += 4) { // for (int block_limit = 16; block_limit < 112; block_limit += 4) {
// run<half>(myRank, nRanks, comm, threads, block_limit, 1024 * 1024, // run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
// performance_test);
// } // }
// } // }
// Scan through different sizes to test performance. #ifdef USE_ROCM
for (int sz = 512; sz <= (8 << 20); sz *= 2) {
run<half>(myRank, nRanks, comm, 512, 16, sz + 8 * 47, performance_test);
}
#else
for (int sz = 512; sz <= (8 << 20); sz *= 2) { for (int sz = 512; sz <= (8 << 20); sz *= 2) {
run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test); run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test);
} }
#endif
cudaProfilerStop(); cudaProfilerStop();
MPICHECK(MPI_Finalize()); MPICHECK(MPI_Finalize());
return EXIT_SUCCESS; return EXIT_SUCCESS;
} }
...@@ -122,8 +122,8 @@ struct ScaledEpilogue ...@@ -122,8 +122,8 @@ struct ScaledEpilogue
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
typename EVTCompute0::Arguments evt0_args{b_args}; typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
return ArgumentType{a_args, evt0_args}; return ArgumentType{a_args, evt0_args, {}};
} }
}; };
...@@ -167,8 +167,8 @@ struct ScaledEpilogueBias ...@@ -167,8 +167,8 @@ struct ScaledEpilogueBias
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias); auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
typename EVTCompute0::Arguments evt0_args{b_args}; typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
return ArgumentType{a_args, evt0_args, bias_args}; return ArgumentType{a_args, evt0_args, bias_args, {}};
} }
}; };
...@@ -230,9 +230,10 @@ struct ScaledEpilogueBiasAzp ...@@ -230,9 +230,10 @@ struct ScaledEpilogueBiasAzp
auto azp_adj_args = auto azp_adj_args =
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj); SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}};
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; typename EVTComputeScaleB::Arguments evt_scale_b_args{
return ArgumentType{a_args, evt_scale_b_args, bias_args}; b_args, evt_azp_args, {}};
return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
} }
}; };
...@@ -309,11 +310,12 @@ struct ScaledEpilogueBiasAzpToken ...@@ -309,11 +310,12 @@ struct ScaledEpilogueBiasAzpToken
auto azp_adj_args = auto azp_adj_args =
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj); SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}};
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}};
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; typename EVTComputeScaleB::Arguments evt_scale_b_args{
return ArgumentType{a_args, evt_scale_b_args, bias_args}; b_args, evt_acc_args, {}};
return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
} }
}; };
}; // namespace vllm::c2x }; // namespace vllm::c2x
\ No newline at end of file
...@@ -22,7 +22,7 @@ struct identity { ...@@ -22,7 +22,7 @@ struct identity {
T operator()(T lhs) const { return lhs; } T operator()(T lhs) const { return lhs; }
}; };
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor> template <typename ElementAcc, typename ElementD, typename TileShape>
struct TrivialEpilogue { struct TrivialEpilogue {
private: private:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch; using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
...@@ -44,32 +44,30 @@ struct TrivialEpilogue { ...@@ -44,32 +44,30 @@ struct TrivialEpilogue {
* This class provides the common load descriptors for the * This class provides the common load descriptors for the
* ScaledEpilogue[...] classes * ScaledEpilogue[...] classes
*/ */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor> template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBase { struct ScaledEpilogueBase {
protected: protected:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch; using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
template <typename T> template <typename T>
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, 0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;
Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T> template <typename T>
using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, 0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;
Stride<Int<0>, Int<1>, Int<0>>>;
// Don't want to support nullptr by default // Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false> template <typename T, bool EnableNullPtr = false>
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, 0 /*Stages*/, TileShape, T, T, Stride<Int<1>, Int<0>, Int<0>>,
Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>; 128 / sizeof_bits_v<T>, EnableNullPtr>;
// Don't want to support nullptr by default // Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false> template <typename T, bool EnableNullPtr = false>
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, 0 /*Stages*/, TileShape, T, T, Stride<Int<0>, Int<1>, Int<0>>,
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>; 128 / sizeof_bits_v<T>, EnableNullPtr>;
// This utility function constructs the arguments for the load descriptors // This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or // from a tensor. It can handle both row and column, as well as row/column or
...@@ -116,11 +114,11 @@ struct ScaledEpilogueBase { ...@@ -116,11 +114,11 @@ struct ScaledEpilogueBase {
the A and B operands respectively. These scales may be either per-tensor or the A and B operands respectively. These scales may be either per-tensor or
per row or column. per row or column.
*/ */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor> template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogue struct ScaledEpilogue
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> { : private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private: private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>; using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
using Accum = typename SUPER::Accum; using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>; using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>; using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
...@@ -146,8 +144,8 @@ struct ScaledEpilogue ...@@ -146,8 +144,8 @@ struct ScaledEpilogue
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
typename EVTCompute0::Arguments evt0_args{b_args}; typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
return ArgumentType{a_args, evt0_args}; return ArgumentType{a_args, evt0_args, {}};
} }
}; };
...@@ -160,11 +158,11 @@ struct ScaledEpilogue ...@@ -160,11 +158,11 @@ struct ScaledEpilogue
* The bias tensor must be per-output channel. * The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel. * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/ */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor> template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBias struct ScaledEpilogueBias
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> { : private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private: private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>; using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
using Accum = typename SUPER::Accum; using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>; using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>; using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
...@@ -193,8 +191,8 @@ struct ScaledEpilogueBias ...@@ -193,8 +191,8 @@ struct ScaledEpilogueBias
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias); auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
typename EVTCompute0::Arguments evt0_args{b_args}; typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
return ArgumentType{a_args, evt0_args, bias_args}; return ArgumentType{a_args, evt0_args, bias_args, {}};
} }
}; };
...@@ -203,11 +201,11 @@ struct ScaledEpilogueBias ...@@ -203,11 +201,11 @@ struct ScaledEpilogueBias
* bias is a column vector instead of a row vector. Useful e.g. if we are * bias is a column vector instead of a row vector. Useful e.g. if we are
* computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels. * computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
*/ */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor> template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueColumnBias struct ScaledEpilogueColumnBias
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> { : private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private: private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>; using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
using Accum = typename SUPER::Accum; using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>; using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>; using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
...@@ -236,8 +234,8 @@ struct ScaledEpilogueColumnBias ...@@ -236,8 +234,8 @@ struct ScaledEpilogueColumnBias
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias); auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
typename EVTCompute0::Arguments evt0_args{b_args}; typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
return ArgumentType{a_args, evt0_args, bias_args}; return ArgumentType{a_args, evt0_args, bias_args, {}};
} }
}; };
...@@ -249,11 +247,11 @@ struct ScaledEpilogueColumnBias ...@@ -249,11 +247,11 @@ struct ScaledEpilogueColumnBias
* *
* This epilogue also supports bias, which remains per-channel. * This epilogue also supports bias, which remains per-channel.
*/ */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor> template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBiasAzp struct ScaledEpilogueBiasAzp
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> { : private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private: private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>; using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
using Accum = typename SUPER::Accum; using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>; using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>; using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
...@@ -297,9 +295,10 @@ struct ScaledEpilogueBiasAzp ...@@ -297,9 +295,10 @@ struct ScaledEpilogueBiasAzp
auto azp_adj_args = auto azp_adj_args =
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj); SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}};
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; typename EVTComputeScaleB::Arguments evt_scale_b_args{
return ArgumentType{a_args, evt_scale_b_args, bias_args}; b_args, evt_azp_args, {}};
return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
} }
}; };
...@@ -313,11 +312,11 @@ struct ScaledEpilogueBiasAzp ...@@ -313,11 +312,11 @@ struct ScaledEpilogueBiasAzp
* *
* This epilogue also supports bias, which remains per-channel. * This epilogue also supports bias, which remains per-channel.
*/ */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor> template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBiasAzpToken struct ScaledEpilogueBiasAzpToken
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> { : private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private: private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>; using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
using Accum = typename SUPER::Accum; using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>; using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>; using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
...@@ -374,10 +373,11 @@ struct ScaledEpilogueBiasAzpToken ...@@ -374,10 +373,11 @@ struct ScaledEpilogueBiasAzpToken
auto azp_adj_args = auto azp_adj_args =
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj); SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}};
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}};
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; typename EVTComputeScaleB::Arguments evt_scale_b_args{
return ArgumentType{a_args, evt_scale_b_args, bias_args}; b_args, evt_acc_args, {}};
return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
} }
}; };
......
...@@ -402,7 +402,7 @@ struct CollectiveMma< ...@@ -402,7 +402,7 @@ struct CollectiveMma<
// TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128 // TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128
TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{},
Layout<Shape<_32, _1>>{}, Layout<Shape<_4, _1>>{}); // (1,1,1) Layout<Shape<_32>>{}, Layout<Shape<_1>>{}); // (1,1,1)
TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{},
Layout<Shape<_1>>{}, Layout<Shape<_1>>{}); // (1,1,1) Layout<Shape<_1>>{}, Layout<Shape<_1>>{}); // (1,1,1)
ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x); ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import enum import enum
from typing import Dict, Union from typing import Union
from cutlass_library import * from cutlass_library import *
...@@ -21,7 +21,7 @@ class MixedInputKernelScheduleType(enum.Enum): ...@@ -21,7 +21,7 @@ class MixedInputKernelScheduleType(enum.Enum):
TmaWarpSpecializedCooperative = enum_auto() TmaWarpSpecializedCooperative = enum_auto()
VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = { VLLMDataTypeNames: dict[Union[VLLMDataType, DataType], str] = {
**DataTypeNames, # type: ignore **DataTypeNames, # type: ignore
**{ **{
VLLMDataType.u4b8: "u4b8", VLLMDataType.u4b8: "u4b8",
...@@ -29,7 +29,7 @@ VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = { ...@@ -29,7 +29,7 @@ VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = {
} }
} }
VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = { VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
**DataTypeTag, # type: ignore **DataTypeTag, # type: ignore
**{ **{
VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t", VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t",
...@@ -37,7 +37,7 @@ VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = { ...@@ -37,7 +37,7 @@ VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
} }
} }
VLLMDataTypeSize: Dict[Union[VLLMDataType, DataType], int] = { VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = {
**DataTypeSize, # type: ignore **DataTypeSize, # type: ignore
**{ **{
VLLMDataType.u4b8: 4, VLLMDataType.u4b8: 4,
...@@ -45,7 +45,7 @@ VLLMDataTypeSize: Dict[Union[VLLMDataType, DataType], int] = { ...@@ -45,7 +45,7 @@ VLLMDataTypeSize: Dict[Union[VLLMDataType, DataType], int] = {
} }
} }
VLLMDataTypeVLLMScalarTypeTag: Dict[Union[VLLMDataType, DataType], str] = { VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = {
VLLMDataType.u4b8: "vllm::kU4B8", VLLMDataType.u4b8: "vllm::kU4B8",
VLLMDataType.u8b128: "vllm::kU8B128", VLLMDataType.u8b128: "vllm::kU8B128",
DataType.u4: "vllm::kU4", DataType.u4: "vllm::kU4",
...@@ -56,7 +56,7 @@ VLLMDataTypeVLLMScalarTypeTag: Dict[Union[VLLMDataType, DataType], str] = { ...@@ -56,7 +56,7 @@ VLLMDataTypeVLLMScalarTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
DataType.bf16: "vllm::kBfloat16", DataType.bf16: "vllm::kBfloat16",
} }
VLLMDataTypeTorchDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = { VLLMDataTypeTorchDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
DataType.u8: "at::ScalarType::Byte", DataType.u8: "at::ScalarType::Byte",
DataType.s8: "at::ScalarType::Char", DataType.s8: "at::ScalarType::Char",
DataType.e4m3: "at::ScalarType::Float8_e4m3fn", DataType.e4m3: "at::ScalarType::Float8_e4m3fn",
...@@ -66,7 +66,7 @@ VLLMDataTypeTorchDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = { ...@@ -66,7 +66,7 @@ VLLMDataTypeTorchDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
DataType.f32: "at::ScalarType::Float", DataType.f32: "at::ScalarType::Float",
} }
VLLMKernelScheduleTag: Dict[Union[ VLLMKernelScheduleTag: dict[Union[
MixedInputKernelScheduleType, KernelScheduleType], str] = { MixedInputKernelScheduleType, KernelScheduleType], str] = {
**KernelScheduleTag, # type: ignore **KernelScheduleTag, # type: ignore
**{ **{
......
...@@ -6,6 +6,11 @@ ...@@ -6,6 +6,11 @@
#include <torch/all.h> #include <torch/all.h>
// Need a special dispatch case macro since we will nest the FP8 dispatch.
// Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t'.
#define AT_DISPATCH_FP8_CASE(enum_type, ...) \
AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, fp8_t, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
...@@ -14,17 +19,32 @@ ...@@ -14,17 +19,32 @@
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
// TODO(luka/varun): use FP8_TYPE macro after refactoring // ROCm devices might use either fn or fnuz, so set up dispatch table for both.
#ifndef USE_ROCM // A host-based check at runtime will create a preferred FP8 type for ROCm
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ // such that the correct kernel is dispatched.
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ #ifdef USE_ROCM
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) #define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
#else AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#else
#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#endif #endif
// When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'.
// See AT_DISPATCH_FP8_CASE above.
#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__)) AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
......
...@@ -21,9 +21,9 @@ ...@@ -21,9 +21,9 @@
namespace vllm { namespace vllm {
// TODO(woosuk): Further optimize this kernel. // TODO(woosuk): Further optimize this kernel.
template <typename scalar_t> template <typename scalar_t, typename fp8_type>
__global__ void rms_norm_static_fp8_quant_kernel( __global__ void rms_norm_static_fp8_quant_kernel(
FP8_TYPE* __restrict__ out, // [..., hidden_size] fp8_type* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size] const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] const scalar_t* __restrict__ weight, // [hidden_size]
const float* __restrict__ scale, // [1] const float* __restrict__ scale, // [1]
...@@ -52,7 +52,7 @@ __global__ void rms_norm_static_fp8_quant_kernel( ...@@ -52,7 +52,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
float x = (float)input[blockIdx.x * hidden_size + idx]; float x = (float)input[blockIdx.x * hidden_size + idx];
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx]; float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
out[blockIdx.x * hidden_size + idx] = out[blockIdx.x * hidden_size + idx] =
scaled_fp8_conversion<true>(out_norm, scale_inv); scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
} }
} }
...@@ -60,10 +60,10 @@ __global__ void rms_norm_static_fp8_quant_kernel( ...@@ -60,10 +60,10 @@ __global__ void rms_norm_static_fp8_quant_kernel(
Additional optimizations we can make in this case are Additional optimizations we can make in this case are
packed and vectorized operations, which help with the packed and vectorized operations, which help with the
memory latency bottleneck. */ memory latency bottleneck. */
template <typename scalar_t, int width> template <typename scalar_t, int width, typename fp8_type>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists> __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel( fused_add_rms_norm_static_fp8_quant_kernel(
FP8_TYPE* __restrict__ out, // [..., hidden_size] fp8_type* __restrict__ out, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size] scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size] scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] const scalar_t* __restrict__ weight, // [hidden_size]
...@@ -114,7 +114,7 @@ fused_add_rms_norm_static_fp8_quant_kernel( ...@@ -114,7 +114,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
#pragma unroll #pragma unroll
for (int i = 0; i < width; ++i) { for (int i = 0; i < width; ++i) {
out[id * width + i] = out[id * width + i] =
scaled_fp8_conversion<true>(float(temp.data[i]), scale_inv); scaled_fp8_conversion<true, fp8_type>(float(temp.data[i]), scale_inv);
} }
} }
} }
...@@ -122,10 +122,10 @@ fused_add_rms_norm_static_fp8_quant_kernel( ...@@ -122,10 +122,10 @@ fused_add_rms_norm_static_fp8_quant_kernel(
/* Generic fused_add_rms_norm_kernel /* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations. The width field is not used here but necessary for other specializations.
*/ */
template <typename scalar_t, int width> template <typename scalar_t, int width, typename fp8_type>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists> __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel( fused_add_rms_norm_static_fp8_quant_kernel(
FP8_TYPE* __restrict__ out, // [..., hidden_size] fp8_type* __restrict__ out, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size] scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size] scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] const scalar_t* __restrict__ weight, // [hidden_size]
...@@ -158,7 +158,7 @@ fused_add_rms_norm_static_fp8_quant_kernel( ...@@ -158,7 +158,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
float x = (float)residual[blockIdx.x * hidden_size + idx]; float x = (float)residual[blockIdx.x * hidden_size + idx];
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx]; float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
out[blockIdx.x * hidden_size + idx] = out[blockIdx.x * hidden_size + idx] =
scaled_fp8_conversion<true>(out_norm, scale_inv); scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
} }
} }
...@@ -176,25 +176,33 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size] ...@@ -176,25 +176,33 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
dim3 block(std::min(hidden_size, 1024)); dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { VLLM_DISPATCH_FLOATING_TYPES(
vllm::rms_norm_static_fp8_quant_kernel<scalar_t> input.scalar_type(), "rms_norm_kernel_scalar_type", [&] {
<<<grid, block, 0, stream>>>( VLLM_DISPATCH_FP8_TYPES(
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(), out.scalar_type(), "rms_norm_kernel_fp8_type", [&] {
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), epsilon, vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t>
num_tokens, hidden_size); <<<grid, block, 0, stream>>>(
}); out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),
epsilon, num_tokens, hidden_size);
});
});
} }
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ #define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \ input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] { \
vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, width> \ VLLM_DISPATCH_FP8_TYPES( \
<<<grid, block, 0, stream>>>( \ out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] { \
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(), \ vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, \
residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \ width, fp8_t> \
scale.data_ptr<float>(), epsilon, num_tokens, hidden_size); \ <<<grid, block, 0, stream>>>( \
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), \
epsilon, num_tokens, hidden_size); \
}); \
}); });
void fused_add_rms_norm_static_fp8_quant( void fused_add_rms_norm_static_fp8_quant(
torch::Tensor& out, // [..., hidden_size], torch::Tensor& out, // [..., hidden_size],
torch::Tensor& input, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size]
......
...@@ -24,3 +24,14 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, ...@@ -24,3 +24,14 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch::Tensor sorted_token_ids, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids, torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad); torch::Tensor num_tokens_post_pad);
#ifndef USE_ROCM
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
torch::Tensor b_qweight, torch::Tensor b_scales,
std::optional<torch::Tensor> b_qzeros,
std::optional<torch::Tensor> topk_weights,
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad, int64_t top_k,
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
int64_t BLOCK_SIZE_K, int64_t bit);
#endif
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment