Commit 02689420 authored by xuxz's avatar xuxz
Browse files

Merge branch 'v0.9.2-dev' into 'v0.9.2-dev-add_connector'

# Conflicts:
#   vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
parents ef362942 fa683b07
......@@ -1056,9 +1056,21 @@ class CustomAllreduce {
size /= d;
auto bytes = size * sizeof(typename packed_t<T>::P);
int blocks = std::min(block_limit, (size + threads - 1) / threads);
// #define KL(ngpus, name) \
// name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
// rank_, size, dev_curr_hdp_reg, world_size_) ;
#define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
rank_, size, dev_curr_hdp_reg, world_size_) ;
{ \
void* kernelArgs[] = { \
&ptrs, &sg_, &self_sg_, &output, &rank_, &size \
}; \
hipExtLaunchKernel( \
(void*)name<T, ngpus>, \
blocks, threads, \
kernelArgs, 0, \
stream, nullptr, stopEvent, 0 \
); \
}
#define REDUCE_CASE(ngpus) \
case ngpus: { \
......
......@@ -173,6 +173,39 @@ __global__ void moe_sum_kernel(
}
}
// Sums the TOPK expert outputs of each token along the hidden dimension.
//
// Layout (as used by the launch site in moe_sum_opt1):
//   input : [num_tokens, TOPK, d]
//   out   : [num_tokens, d]
// Launch : gridDim.x == num_tokens * SPLIT_D, 1-D blocks. Each token's hidden
//          dimension is cut into SPLIT_D contiguous slices, one slice per
//          block. blockDim.x must not exceed BLOCK_DIM, because the static
//          shared tile below is indexed by threadIdx.x.
// Shared : TOPK * BLOCK_DIM * sizeof(scalar_t) bytes (static).
//
// The accumulation is performed in scalar_t, in expert order k = 0..TOPK-1.
template <typename scalar_t, int TOPK, int SPLIT_D, int BLOCK_DIM>
__global__ void moe_sum_sharedmem(
    scalar_t* __restrict__ out,
    const scalar_t* __restrict__ input,
    const int d) {
  const int token_idx = blockIdx.x / SPLIT_D;
  const int sub_block = blockIdx.x % SPLIT_D;
  const int d_per_block = (d + SPLIT_D - 1) / SPLIT_D;

  // Slice [d_start, d_end) of the hidden dimension handled by this block.
  const int64_t d_start = static_cast<int64_t>(sub_block) * d_per_block;
  const int64_t d_limit = d_start + d_per_block;
  const int64_t d_end = d_limit < d ? d_limit : static_cast<int64_t>(d);

  // Widen BEFORE multiplying: token_idx * TOPK * d overflows 32-bit int for
  // large token counts / hidden sizes.
  const int64_t token_offset = static_cast<int64_t>(token_idx) * TOPK * d;
  const int64_t out_offset = static_cast<int64_t>(token_idx) * d;

  __shared__ __align__(16) scalar_t sem_input[TOPK][BLOCK_DIM];

  // Iterate on a block-uniform base index so that every thread of the block
  // executes the same number of iterations and therefore reaches the barrier
  // the same number of times, even when the slice length is not a multiple of
  // blockDim.x. Threads past the end of the slice only skip the memory work.
  for (int64_t base = d_start; base < d_end; base += blockDim.x) {
    const int64_t idx = base + threadIdx.x;
    const bool in_range = idx < d_end;

    if (in_range) {
#pragma unroll
      for (int k = 0; k < TOPK; ++k) {
        sem_input[k][threadIdx.x] =
            input[token_offset + static_cast<int64_t>(k) * d + idx];
      }
    }
    __syncthreads();

    if (in_range) {
      scalar_t x = 0;
#pragma unroll
      for (int k = 0; k < TOPK; ++k) {
        x += sem_input[k][threadIdx.x];
      }
      out[out_offset + idx] = x;
    }
    // No trailing barrier needed: each thread only ever touches its own
    // column sem_input[*][threadIdx.x].
  }
}
template <typename scalar_t, int TOPK, int SPLIT_D, int BLOCK_DIM>
__global__ void moe_sum_sharedmem_topk8(
scalar_t* __restrict__ out,
......@@ -440,7 +473,14 @@ void moe_sum_opt1(torch::Tensor& input, // [num_tokens, topk, hidden_size]
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size);
});
break;
case 9:
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_sharedmem", [&]{
vllm::moe::moe_sum_sharedmem<scalar_t, 9, 9, 256><<<num_tokens * 9, 256, 0, stream>>>(
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size);
});
break;
default:
......
......@@ -27,7 +27,7 @@ static inline __device__ float fp8_to_float(uint8_t input) {
}
// float -> fp8
static inline __device__ uint8_t float_to_fp8(float f) {
static inline __device__ uint8_t float_to_fp8_e4m3(float f) {
constexpr uint32_t fp8_max = UINT32_C(1087) << 20;
constexpr uint32_t denorm_mask = UINT32_C(141) << 23;
uint32_t f_bits = c10::detail::fp32_to_bits(f);
......@@ -53,10 +53,35 @@ static inline __device__ uint8_t float_to_fp8(float f) {
return result;
}
// float -> fp8 (1 sign bit, 5 exponent bits, 2 mantissa bits).
// Works on the raw fp32 bit pattern:
//   * magnitudes whose bits are >= (143 << 23) map to 0x7C, or to 0x7F when
//     the input bits are above the fp32 infinity pattern (i.e. NaN);
//   * magnitudes whose bits are below (113 << 23) go through a float add with
//     a magic constant so the hardware rounding produces the fp8 subnormal;
//   * everything else is re-biased (127 -> 15) and rounded to nearest-even
//     before dropping the low 21 mantissa bits.
// The sign bit is moved from bit 31 to bit 7 of the result.
static inline __device__ uint8_t float_to_fp8_e5m2(float f) {
  constexpr uint32_t kFp32InfBits = UINT32_C(255) << 23;
  constexpr uint32_t kOverflowBits = UINT32_C(143) << 23;
  constexpr uint32_t kMinNormalBits = UINT32_C(113) << 23;
  constexpr uint32_t kDenormMagic = UINT32_C(134) << 23;

  const uint32_t raw_bits = c10::detail::fp32_to_bits(f);
  const uint32_t sign_bit = raw_bits & UINT32_C(0x80000000);
  uint32_t magnitude = raw_bits ^ sign_bit;  // |f| as bits

  uint8_t encoded = 0u;
  if (magnitude >= kOverflowBits) {
    // Too large for a finite fp8 value: NaN stays NaN, the rest saturates
    // to the 0x7C pattern.
    encoded = magnitude > kFp32InfBits ? UINT8_C(0x7F) : UINT8_C(0x7C);
  } else if (magnitude < kMinNormalBits) {
    // Result is an fp8 subnormal: let the fp32 adder do the rounding.
    const float biased = c10::detail::fp32_from_bits(magnitude) +
                         c10::detail::fp32_from_bits(kDenormMagic);
    magnitude = c10::detail::fp32_to_bits(biased);
    encoded = static_cast<uint8_t>(magnitude - kDenormMagic);
  } else {
    // Normal range: capture the lowest kept mantissa bit for ties-to-even,
    // re-bias the exponent and add the rounding increment.
    const uint32_t kept_lsb = (magnitude >> 21) & 1;
    magnitude += ((uint32_t)(15 - 127) << 23) + 0xFFFFF;
    magnitude += kept_lsb;
    encoded = static_cast<uint8_t>(magnitude >> 21);
  }

  encoded |= static_cast<uint8_t>(sign_bit >> 24);
  return encoded;
}
template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
const float scale) {
const float scale, Fp8KVCacheDataType kv_type) {
return x;
}
......@@ -65,8 +90,10 @@ using __nv_bfloat16 = __hip_bfloat16;
// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) {
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) {
assert(false);
}
return __float2bfloat16(fp8_to_float(a) * scale);
}
......@@ -74,32 +101,32 @@ scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
template <>
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
float scale) {
float scale, Fp8KVCacheDataType kv_type) {
__nv_bfloat162 res;
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale, kv_type);
res.y =
scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale, kv_type);
return res;
}
// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t
scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale) {
scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
bf16_4_t res;
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale, kv_type);
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
scale);
scale, kv_type);
return res;
}
// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
bf16_4_t tmp1, tmp2;
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, kv_type);
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, kv_type);
bf16_8_t res;
res.x = tmp1.x;
res.y = tmp1.y;
......@@ -111,45 +138,48 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
// fp8 -> float
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
const uint8_t& a, float scale) {
const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) {
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) {
assert(false);
}
return fp8_to_float(a) * scale;
}
// fp8x2 -> float2
template <>
__inline__ __device__ float2
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) {
float2 f2r;
f2r.x = scaled_vec_conversion<float, uint8_t>((uint8_t)a, scale);
f2r.y = scaled_vec_conversion<float, uint8_t>((uint8_t)(a >> 8U), scale);
f2r.x = scaled_vec_conversion<float, uint8_t>((uint8_t)a, scale, kv_type);
f2r.y = scaled_vec_conversion<float, uint8_t>((uint8_t)(a >> 8U), scale, kv_type);
return f2r;
}
// fp8x4 -> float4
template <>
__inline__ __device__ Float4_
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale, Fp8KVCacheDataType kv_type) {
Float4_ res;
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, kv_type);
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale, kv_type);
return res;
}
// fp8x4 -> float4
template <>
__inline__ __device__ float4
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale) {
Float4_ res = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
Float4_ res = scaled_vec_conversion<Float4_, uint32_t>(a, scale, kv_type);
return {res.x.x, res.x.y, res.y.x, res.y.y};
}
// fp8x8 -> float8
template <>
__inline__ __device__ Float8_
scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
Float4_ tmp1, tmp2;
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, kv_type);
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, kv_type);
Float8_ res;
res.x = tmp1.x;
res.y = tmp1.y;
......@@ -161,7 +191,10 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
// fp8 -> half
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) {
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) {
assert(false);
}
float res = fp8_to_float(a) * scale;
return float_to_half(res);
}
......@@ -169,54 +202,58 @@ scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint16_t u16[2];
uint32_t u32;
} res;
res.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)a, scale);
res.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)(a >> 8U), scale);
res.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)a, scale, kv_type);
res.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)(a >> 8U), scale, kv_type);
return res.u32;
}
// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) {
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint2 u32x2;
uint32_t u32[2];
} tmp;
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale, kv_type);
tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale, kv_type);
return tmp.u32x2;
}
// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a,
float scale) {
float scale, Fp8KVCacheDataType kv_type) {
union {
uint4 u64x2;
uint2 u64[2];
} tmp;
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, kv_type);
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, kv_type);
return tmp.u64x2;
}
// half -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) {
float res_f = half_to_float(a) / scale;
return float_to_fp8(res_f);
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) {
return float_to_fp8_e4m3(res_f);
} else {
return float_to_fp8_e5m2(res_f);
}
}
// halfx2 -> fp8x2
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint8_t ui8[2];
uint16_t ui16;
......@@ -226,113 +263,122 @@ scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
half2 h2r;
} tmp_a;
tmp_a.ui32 = a;
tmp.ui8[0] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[0], scale);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[1], scale);
tmp.ui8[0] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[0], scale, kv_type);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[1], scale, kv_type);
return tmp.ui16;
}
// half2x2 -> fp8x4
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale) {
scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint16_t ui16[2];
uint32_t ui32;
} tmp;
tmp.ui16[0] = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale);
tmp.ui16[0] = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale, kv_type);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale, kv_type);
return tmp.ui32;
}
// half2x4 -> fp8x8
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a,
float scale) {
float scale, Fp8KVCacheDataType kv_type) {
union {
uint2 ui2[2];
uint4 ui4;
} tmp;
tmp.ui4 = a;
uint2 res;
res.x = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[0], scale);
res.y = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[1], scale);
res.x = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[0], scale, kv_type);
res.y = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[1], scale, kv_type);
return res;
}
// bf16 -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
const __nv_bfloat16& a, float scale) {
const __nv_bfloat16& a, float scale, Fp8KVCacheDataType kv_type) {
float res_f = (static_cast<float>(a)) / scale;
return float_to_fp8(res_f);
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) {
return float_to_fp8_e4m3(res_f);
} else {
return float_to_fp8_e5m2(res_f);
}
}
// bf16x2 -> fp8x2
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, __nv_bfloat162>(
const __nv_bfloat162& a, float scale) {
const __nv_bfloat162& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint8_t ui8[2];
uint16_t ui16;
} tmp;
tmp.ui8[0] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale);
tmp.ui8[0] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale, kv_type);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale, kv_type);
return tmp.ui16;
}
// bf16x4 -> fp8x4
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale) {
scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint16_t ui16[2];
uint32_t ui32;
} tmp;
tmp.ui16[0] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale);
tmp.ui16[0] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale, kv_type);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale, kv_type);
return tmp.ui32;
}
// bf16x8 -> fp8x8
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) {
scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale, Fp8KVCacheDataType kv_type) {
uint2 res;
res.x = scaled_vec_conversion<uint32_t, bf16_4_t>({a.x, a.y}, scale);
res.y = scaled_vec_conversion<uint32_t, bf16_4_t>({a.z, a.w}, scale);
res.x = scaled_vec_conversion<uint32_t, bf16_4_t>({a.x, a.y}, scale, kv_type);
res.y = scaled_vec_conversion<uint32_t, bf16_4_t>({a.z, a.w}, scale, kv_type);
return res;
}
// float -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
return float_to_fp8(a / scale);
scaled_vec_conversion<uint8_t, float>(const float& a, float scale, Fp8KVCacheDataType kv_type) {
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) {
return float_to_fp8_e4m3(a / scale);
} else {
return float_to_fp8_e5m2(a / scale);
}
}
// floatx2 -> fp8x2
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint8_t ui8[2];
uint16_t ui16;
} tmp;
tmp.ui8[0] = scaled_vec_conversion<uint8_t, float>(a.x, scale);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, float>(a.y, scale);
tmp.ui8[0] = scaled_vec_conversion<uint8_t, float>(a.x, scale, kv_type);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, float>(a.y, scale, kv_type);
return tmp.ui16;
}
// floatx4 -> fp8x4
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint16_t ui16[2];
uint32_t ui32;
} tmp;
tmp.ui16[0] = scaled_vec_conversion<uint16_t, float2>({a.x, a.y}, scale);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale);
tmp.ui16[0] = scaled_vec_conversion<uint16_t, float2>({a.x, a.y}, scale, kv_type);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale, kv_type);
return tmp.ui32;
}
......@@ -433,8 +479,8 @@ scaled_vec_conversion_from_e5m2<__nv_bfloat16>(const uint8_t& a, float scale) {
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return scaled_vec_conversion<Tout, Tin>(x, scale);
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3 || kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
return scaled_vec_conversion<Tout, Tin>(x, scale, kv_dt);
}
else if constexpr(kv_dt == Fp8KVCacheDataType::kFp8E5M2 && sizeof(Tout)==1){
return scaled_vec_conversion_to_e5m2<Tin>(x, scale);
......
......@@ -11,6 +11,8 @@ import subprocess
import sys
from pathlib import Path
from shutil import which
import tarfile
import shutil
import torch
from packaging.version import Version, parse
......@@ -31,6 +33,39 @@ skip_vllm_build = False
if int(os.environ.get('SKIP_VLLM_BUILD', '0')) == 1:
skip_vllm_build = True
def prepare_so_files():
    """Copy prebuilt extension libraries out of ``so.tar.gz`` into ``vllm/``.

    Looks for archive members named ``_C.abi3.so`` or ``_moe_C.abi3.so`` at
    any depth and writes each of them to ``vllm/<name>``. If the archive is
    missing, a warning is printed and nothing is done.

    Only the matching regular-file members are read, and they are streamed
    straight to their destination. Nothing is ever extracted under a path
    taken from the archive, so a crafted member name (``../..`` or absolute)
    cannot write outside ``vllm/``, and no scratch directory is created or
    removed in the current working directory.
    """
    archive_path = "so.tar.gz"
    target_dir = "vllm"
    wanted_names = ("_C.abi3.so", "_moe_C.abi3.so")

    if not os.path.exists(archive_path):
        print(f"Warning: {archive_path} not found, skipping extraction")
        return

    print(f"Preparing C extension files from {archive_path}...")

    with tarfile.open(archive_path, "r:*") as tar:
        for member in tar:
            # Skip directories, links and device nodes; only real files are
            # candidates for copying.
            if not member.isfile():
                continue
            file = os.path.basename(member.name)
            if file not in wanted_names:
                continue

            dst_path = os.path.join(target_dir, file)
            os.makedirs(os.path.dirname(dst_path), exist_ok=True)

            src = tar.extractfile(member)
            if src is None:
                continue
            with src, open(dst_path, "wb") as dst:
                shutil.copyfileobj(src, dst)

            # Keep the archived permission bits and timestamp, as the former
            # extract-then-copy2 flow did; always leave the file owner-rw.
            os.chmod(dst_path, (member.mode & 0o777) | 0o600)
            os.utime(dst_path, (member.mtime, member.mtime))
            print(f"Copied {file} to {dst_path}")
def load_module_from_path(module_name, path):
spec = importlib.util.spec_from_file_location(module_name, path)
module = importlib.util.module_from_spec(spec)
......@@ -559,10 +594,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha is None:
sha = get_sha(vllm_root)
if (major, minor) >= ('2', '5'):
version = 'das.opt3.' + sha[:7]
version = 'das.opt5.' + sha[:7]
else:
if (major, minor) >= ('2', '5'):
version = 'das.opt3'
version = 'das.opt5'
# dtk version
......@@ -769,6 +804,7 @@ if _build_custom_ops():
ext_modules.append(CMakeExtension(name="vllm._C"))
if skip_vllm_build:
prepare_so_files()
package_data = {
"vllm": [
"py.typed",
......
......@@ -14,6 +14,7 @@ from vllm.utils import direct_register_custom_op
try:
from lmslim import quant_ops
from lmslim import quant_tools
from lmslim.layers.gemm.fp8_utils import per_token_quant_fp8
except Exception:
print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n")
try:
......@@ -1692,66 +1693,67 @@ def scaled_fp4_experts_quant(
# fp8
# def scaled_fp8_quant(
# input: torch.Tensor,
# scale: Optional[torch.Tensor] = None,
# num_token_padding: Optional[int] = None,
# scale_ub: Optional[torch.Tensor] = None,
# use_per_token_if_dynamic: bool = False,
# output: Optional[torch.Tensor] = None,
# ) -> tuple[torch.Tensor, torch.Tensor]:
# """
# Quantize input tensor to FP8 and return quantized tensor and scale.
# This function supports both static and dynamic quantization: If you
# provide the scale, it will use static scaling and if you omit it,
# the scale will be determined dynamically. The function also allows
# optional padding of the output tensors for downstream kernels that
# will benefit from padding.
# Args:
# input: The input tensor to be quantized to FP8
# scale: Optional scaling factor for the FP8 quantization
# scale_ub: Optional upper bound for scaling factor in dynamic
# per token case
# num_token_padding: If specified, pad the first dimension
# of the output to at least this value.
# use_per_token_if_dynamic: Whether to do per_tensor or per_token
# in the dynamic quantization case.
# Returns:
# tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
# scaling factor.
# """
# # This code assumes batch_dim and num_tokens are flattened
# assert (input.ndim == 2)
# shape: Union[tuple[int, int], torch.Size] = input.shape
# # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
# out_dtype: torch.dtype = current_platform.fp8_dtype()
# if num_token_padding:
# shape = (max(num_token_padding, input.shape[0]), shape[1])
# if output is None:
# output = torch.empty(shape, device=input.device, dtype=out_dtype)
# else:
# assert num_token_padding is None, \
# "padding not supported if output passed in"
# assert output.dtype == out_dtype
# if scale is None:
# if use_per_token_if_dynamic:
# scale = torch.empty((shape[0], 1),
# device=input.device,
# dtype=torch.float32)
# torch.ops._C.dynamic_per_token_scaled_fp8_quant(
# output, input.contiguous(), scale, scale_ub)
# else:
# scale = torch.zeros(1, device=input.device, dtype=torch.float32)
# torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
# else:
# assert scale.numel() == 1, f"{scale.shape}"
# torch.ops._C.static_scaled_fp8_quant(output, input, scale)
# return output, scale
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
    output: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization: If you
    provide the scale, it will use static scaling and if you omit it,
    the scale will be determined dynamically. The function also allows
    optional padding of the output tensors for downstream kernels that
    will benefit from padding.

    Args:
        input: The input tensor to be quantized to FP8 (must be 2-D).
        scale: Optional scaling factor for the FP8 quantization
        scale_ub: Optional upper bound for scaling factor in dynamic
            per token case
        num_token_padding: If specified, pad the first dimension
            of the output to at least this value.
        use_per_token_if_dynamic: Whether to do per_tensor or per_token
            in the dynamic quantization case.
        output: Optional pre-allocated FP8 tensor to write the result into.
            Cannot be combined with ``num_token_padding``.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
            scaling factor.
    """
    # This code assumes batch_dim and num_tokens are flattened
    assert (input.ndim == 2)
    shape: Union[tuple[int, int], torch.Size] = input.shape
    # The FP8 output dtype is chosen by the current platform.
    out_dtype: torch.dtype = current_platform.fp8_dtype()
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])

    # The per-token path below gets a freshly allocated tensor back from its
    # helper. Remember whether the caller relies on a specific buffer (their
    # own `output`, or a padded one) so the result can be copied into it
    # instead of silently replacing it.
    must_fill_buffer = output is not None or bool(num_token_padding)

    if output is None:
        output = torch.empty(shape, device=input.device, dtype=out_dtype)
    else:
        assert num_token_padding is None, \
            "padding not supported if output passed in"
        assert output.dtype == out_dtype

    if scale is None:
        if use_per_token_if_dynamic:
            # torch.ops._C.dynamic_per_token_scaled_fp8_quant(
            #     output, input.contiguous(), scale, scale_ub)
            # NOTE(review): `scale_ub` is not forwarded to the helper, so the
            # upper bound is not applied on this path - confirm intended.
            quantized, scale = per_token_quant_fp8(input.contiguous())
            if must_fill_buffer:
                # Presumably `quantized` has one row per input token and the
                # same fp8 dtype as `output` - TODO confirm against lmslim.
                output[:quantized.shape[0]].copy_(quantized)
            else:
                output = quantized
        else:
            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        assert scale.numel() == 1, f"{scale.shape}"
        torch.ops._C.static_scaled_fp8_quant(output, input, scale)

    return output, scale
# gptq allspark
......
......@@ -204,6 +204,8 @@ class Attention(nn.Module):
# shape does not match the query shape, so we optionally let the model
# definition specify the output tensor shape.
output_shape: Optional[torch.Size] = None,
query_nope: Optional[torch.Size] = None,
num_local_heads: Optional[int] = None,
q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None,
......@@ -270,7 +272,7 @@ class Attention(nn.Module):
query, key, value, output, self.layer_name)
else:
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name, None, q_ori, key_normed, positions, weight, cos_sin_cache)
query, key, value, output, self.layer_name, None, query_nope, num_local_heads, q_ori, key_normed, positions, weight, cos_sin_cache)
return output.view(-1, hidden_size)
else:
if self.use_direct_call:
......@@ -511,6 +513,8 @@ def unified_attention_with_output(
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
query_nope: Optional[torch.Tensor] = None,
num_local_heads: Optional[int] = None,
q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None,
......@@ -542,6 +546,8 @@ def unified_attention_with_output(
attn_metadata,
output=output,
output_scale=output_scale,
query_nope=query_nope,
num_local_heads=num_local_heads,
q_ori=q_ori,
key_normed=key_normed,
positions=positions,
......@@ -560,6 +566,8 @@ def unified_attention_with_output_fake(
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
query_nope: Optional[torch.Tensor] = None,
num_local_heads: Optional[int] = None,
q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None,
......
......@@ -277,6 +277,60 @@ def flash_mla_with_kvcache_fp8(
)
return out, softmax_lse
def flash_mla_with_kvcache_fp8_with_cat(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    FP8 MLA attention over a paged KV cache, taking the query as two
    separate tensors (``q_nope`` and ``q_pe``) instead of one concatenated
    tensor.

    Arguments:
        q_nope: (batch_size, seq_len_q, num_heads_q, 512).
        q_pe: (batch_size, seq_len_q, num_heads_q, 64).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            torch.int32, returned by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, returned by
            get_mla_metadata.
        softmax_scale: float. The scale of QK^T before applying softmax.
            Defaults to 1 / sqrt(d), where d is the sum of the last
            dimensions of q_nope and q_pe.
        causal: bool. Whether to apply causal attention mask.
        descale_q: (batch_size), torch.float32. Descaling factors for Q,
            used for fp8 quantization.
        descale_k: (batch_size), torch.float32. Descaling factors for K,
            used for fp8 quantization.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        # Scale by the full query head size (nope part + rope part).
        qk_head_dim = q_nope.shape[-1] + q_pe.shape[-1]
        softmax_scale = qk_head_dim ** (-0.5)

    # The extension takes its arguments positionally; the fourth one is
    # always passed as None here.
    attn_out, lse = flash_mla_cuda.fwd_kvcache_mla_fp8_with_cat(
        q_nope,
        q_pe,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        descale_q,
        descale_k,
    )
    return attn_out, lse
#
# TODO: Add fake functions
#
......
......@@ -5,6 +5,7 @@ from typing import Optional
from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
import torch
logger = init_logger(__name__)
......@@ -68,6 +69,8 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
def flash_attn_supports_fp8() -> bool:
    """Report whether the flash-attention backend can be used with FP8.

    Always True on ROCm platforms. Otherwise True only when the detected
    flash-attention version is 3 and the device capability major is 9.
    """
    if not current_platform.is_rocm():
        is_fa3 = get_flash_attn_version() == 3
        return is_fa3 and \
            current_platform.get_device_capability().major == 9
    return True
......
......@@ -1883,6 +1883,9 @@ class ParallelConfig:
""" Use data parallelism instead of tensor parallelism for vision encoder.
Only support LLama4 for now"""
enable_dp_attention: bool = False
"""Enable dp attention"""
@property
def world_size_across_dp(self) -> int:
"""world_size_across_dp is TPxPPxDP, it is the size of the world
......@@ -1939,6 +1942,24 @@ class ParallelConfig:
assert last_exc is not None
raise last_exc
# The all_reduce at the end of attention (during o_proj) means that
# inputs are replicated across each rank of the tensor parallel group.
# If using expert-parallelism with DeepEP All2All ops, replicated
# tokens result in useless duplicate computation and communication.
#
# In this case, ensure the input to the experts is sequence parallel
# to avoid the excess work.
#
# Not needed for pplx-kernels as it can handle duplicate input tokens.
@property
def use_sequence_parallel_moe(self) -> bool:
    """Whether the expert inputs should be sequence parallel.

    Requires all of: VLLM_ALL2ALL_BACKEND being one of the backends listed
    below, expert parallelism enabled, tensor_parallel_size > 1 and
    data_parallel_size > 1.
    """
    sp_capable_backends = ("allgather_reducescatter", "naive",
                           "deepep_high_throughput", "deepep_low_latency")
    if envs.VLLM_ALL2ALL_BACKEND not in sp_capable_backends:
        return False
    return (self.enable_expert_parallel
            and self.tensor_parallel_size > 1
            and self.data_parallel_size > 1)
@staticmethod
def has_unfinished_dp(dp_group: "ProcessGroup",
has_unfinished: bool) -> bool:
......@@ -2091,6 +2112,9 @@ class ParallelConfig:
raise ValueError("Unable to use nsight profiling unless workers "
"run with Ray.")
if self.enable_dp_attention and self.enable_expert_parallel:
raise ValueError("Dp attention and expert parallel can not enable together.")
return self
......@@ -4787,6 +4811,7 @@ class VllmConfig:
dp_size = self.parallel_config.data_parallel_size
tp_size = self.parallel_config.tensor_parallel_size
ep_sp = self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1
enable_dp_attention = self.parallel_config.enable_dp_attention
# add for spec decode
if self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0:
......@@ -4795,10 +4820,10 @@ class VllmConfig:
batch_size_capture_list = sorted(set(batch_size_capture_list + mtp_batch_size_capture_list))
batch_size_capture_list = [i for i in batch_size_capture_list if i == 1 or i % (1 + self.speculative_config.num_lookahead_slots) == 0]
if ep_sp:
if ep_sp or enable_dp_attention:
batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list]))
else:
if ep_sp:
if ep_sp or enable_dp_attention:
batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list]))
self.compilation_config.init_with_cudagraph_sizes(
......
......@@ -194,6 +194,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank,
allow_mnnvl=envs.VLLM_ALLOW_MNNVL,
explicitly_destroy=False)
def get_handle(self, kwargs):
......@@ -256,6 +257,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=num_qps_per_rank,
allow_mnnvl=envs.VLLM_ALLOW_MNNVL,
)
def get_handle(self, kwargs):
......@@ -274,3 +276,56 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
# in get_or_create must be updated.
handle.set_num_sms(self.num_sms)
return handle
class DeepEPAutoAll2AllManager(All2AllManagerBase):
    """
    Simplified auto manager that always builds handles through the
    low-latency DeepEP manager. This avoids creating multiple buffer
    instances and mirrors the sglang behavior of relying on LL buffers.

    Both an LL and an HT manager are constructed; the HT manager is only
    used to compute its buffer-size requirements, while the handle itself
    lives in the LL manager's handle cache.
    """

    def __init__(self, cpu_group):
        super().__init__(cpu_group)
        self.ll_manager = DeepEPLLAll2AllManager(cpu_group)
        self.ht_manager = DeepEPHTAll2AllManager(cpu_group)

    def get_handle(self, kwargs):
        """
        Build a DeepEP Buffer using LL args but sized to the larger of HT/LL
        requirements (max of num_nvl_bytes/num_rdma_bytes).

        Args:
            kwargs: mapping forwarded (as keyword arguments) to the LL
                manager's ``_make_all2all_kwargs``; it is copied first so
                the caller's mapping is not modified.

        Returns:
            The ``deep_ep.Buffer`` obtained from the LL manager's handle
            cache, with its SM count set from the LL manager.
        """
        import deep_ep

        kwargs = dict(kwargs)

        # Build canonical kwargs for each path.
        ll_kwargs = self.ll_manager._make_all2all_kwargs(**kwargs)
        ht_kwargs = self.ht_manager._make_all2all_kwargs()

        # Take the max for buffer sizes to be compatible with both modes.
        merged_kwargs = dict(ll_kwargs)
        for size_key in ("num_nvl_bytes", "num_rdma_bytes"):
            merged_kwargs[size_key] = max(ll_kwargs[size_key],
                                          ht_kwargs[size_key])

        logger.debug("DeepEP auto merged args %s", merged_kwargs)

        handle: deep_ep.Buffer = self.ll_manager.handle_cache.get_or_create(
            merged_kwargs, deep_ep.Buffer)
        handle.set_num_sms(self.ll_manager.num_sms)
        return handle

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
        raise NotImplementedError(
            "DeepEPAutoAll2AllManager does not support dispatch directly; "
            "use the underlying HT/LL managers.")

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError(
            "DeepEPAutoAll2AllManager does not support combine directly; "
            "use the underlying HT/LL managers.")

    def destroy(self):
        """Tear down both sub-managers created in ``__init__``."""
        self.ll_manager.destroy()
        # The HT manager is owned by this object as well; release it too so
        # it is not left behind when the auto manager is destroyed.
        self.ht_manager.destroy()
......@@ -103,7 +103,7 @@ class DeviceCommunicatorBase:
# as long as we use data parallel (coupled data parallel
# where all data parallel ranks execute forward together),
# we initialize the all2all manager used in expert parallel.
use_ep = config.parallel_config.data_parallel_size > 1
use_ep = config.parallel_config.data_parallel_size > 1 and not config.parallel_config.enable_dp_attention
self.use_all2all = "ep" in unique_name and use_ep
self.all2all_manager: Optional[All2AllManagerBase] = None
......
......@@ -88,6 +88,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
from .all2all import DeepEPLLAll2AllManager
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
logger.info("Using DeepEP Low-Latency all2all manager.")
elif all2all_backend == "deepep_auto":
from .all2all import DeepEPAutoAll2AllManager
self.all2all_manager = DeepEPAutoAll2AllManager(self.cpu_group)
logger.info("Using DeepEP Auto all2all manager.")
else:
raise ValueError(f"Unknown all2all backend: {all2all_backend}")
......
......@@ -121,7 +121,11 @@ class CustomAllreduce:
else:
device_ids = list(range(cuda_device_count_stateless()))
if (world_size == len(device_ids)):
physical_device_id = device_ids[device.index % world_size]
else:
physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id],
dtype=torch.int,
device="cpu")
......@@ -267,10 +271,7 @@ class CustomAllreduce:
if envs.VLLM_CUSTOM_CACHE:
return self.all_reduce(input, registered=True)
else:
if not self.fully_connected:
return self.all_reduce(input, registered=False)
else:
return self.all_reduce(input, registered=True)
else:
# If warm up, mimic the allocation pattern since custom
# allreduce is out-of-place.
......
......@@ -73,8 +73,25 @@ class P2pNcclEngine:
self.device = torch.device(f"cuda:{self.local_rank}")
self.nccl = NCCLLibrary(library_path)
if not hostname:
hostname = get_ip()
self.total_num_hidden_layers = getattr(self.model_config.hf_text_config,
"num_hidden_layers", 0)
self.pp_rank = get_pp_group().rank_in_group
self.tp_rank = get_tp_group().rank_in_group
self.pp_size = get_pp_group().world_size
self.tp_size = get_tp_group().world_size
if config.is_kv_producer:
self.remote_tp_size = self.config.get_from_extra_config(
"remote_tp_size", 1)
self.remote_pp_size = self.config.get_from_extra_config(
"remote_pp_size", 1)
self.enable_asymmetric_p2p = self.config.get_from_extra_config(
"enable_asymmetric_p2p", False)
if self.enable_asymmetric_p2p == True:
if self.remote_tp_size % self.tp_size != 0:
logger.error(" the Prefill TP size must be less than or equal to the Decode TP size!!!!")
self.multp = int(self.remote_tp_size / self.tp_size)
self.multiple_machines = self.config.get_from_extra_config(
"enable_multiple_machines", False)
port = int(self.config.kv_port) + port_offset
if port == 0:
raise ValueError("Port cannot be 0")
......@@ -409,6 +426,117 @@ class P2pNcclEngine:
"🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
" to be empty, rank:%d", duration * 1000, self.rank)
def _send_kv_p2p_sync(self, tensor_id: str, kv_layer: torch.Tensor,
slot_mapping: torch.Tensor, remote_address: str) -> bool:
# Gather the rows of kv_layer selected by slot_mapping and send them to
# remote_address in packs of at most self.p2p_async_kv_tokens rows.
# Every pack is announced with a msgpack "PUT" header over the socket and
# is transmitted with self._send once the peer has answered b"0".
# Returns False as soon as the peer answers anything else, True otherwise.
if remote_address not in self.socks:
self._create_connect(remote_address)
sock = self.socks[remote_address]
comm, rank = self.comms[remote_address]
# A 3-D kv_layer is handled as the MLA layout; any other rank is reshaped
# below with a leading dimension of 2 (presumably K and V -- TODO confirm).
is_mla = (kv_layer.ndim == 3)
hidden_dim = kv_layer.shape[-1]
# The staging buffer is allocated on first use and reused by later calls.
# NOTE(review): its last dimension is kv_layer.shape[-1], while the
# reshape(..., -1) calls below flatten all trailing dimensions -- confirm
# the two always agree for the non-MLA layout.
if self.p2p_async_buf is None:
if is_mla:
self.p2p_async_buf = torch.empty((self.p2p_async_kv_tokens, hidden_dim),
dtype=kv_layer.dtype, device=kv_layer.device)
else:
self.p2p_async_buf = torch.empty((2, self.p2p_async_kv_tokens, hidden_dim),
dtype=kv_layer.dtype, device=kv_layer.device)
# Ceiling division: number of packs needed to cover every slot.
pack_num = (slot_mapping.shape[0] - 1) // self.p2p_async_kv_tokens + 1
self.tensor_split_num = pack_num
with torch.cuda.stream(self.send_stream):
for pack_idx in range(pack_num):
# Half-open slot range [start, end) covered by this pack.
start = pack_idx * self.p2p_async_kv_tokens
end = min((pack_idx + 1) * self.p2p_async_kv_tokens, slot_mapping.shape[0])
sub_index = slot_mapping[start:end]
if is_mla:
num_pages, page_size = kv_layer.shape[0], kv_layer.shape[1]
# Flatten (page, slot-in-page) so slot_mapping entries index rows.
data = kv_layer.reshape(num_pages * page_size, -1)
torch.index_select(data, dim=0, index=sub_index, out=self.p2p_async_buf[:end-start])
tx_shape = (end - start, hidden_dim)
else:
num_pages, page_size = kv_layer.shape[1], kv_layer.shape[2]
# Same flattening, keeping the leading dimension of size 2.
data = kv_layer.reshape(2, num_pages * page_size, -1)
torch.index_select(data, dim=1, index=sub_index, out=self.p2p_async_buf[:, :end-start])
tx_shape = (2, end - start, hidden_dim)
# View of the staging buffer that holds exactly this pack's rows.
if is_mla:
send_tensor = self.p2p_async_buf[:end-start]
else:
send_tensor = self.p2p_async_buf[:, :end-start]
header = {
"cmd": "PUT",
"tensor_id": tensor_id + "#" + str(pack_idx), # append pack_idx
"pack_idx": pack_idx,
"tensor_split_num": pack_num,
"shape": tx_shape,
"dtype": str(kv_layer.dtype).replace("torch.", "")
}
# Announce the pack and wait for the peer's reply before sending data.
sock.send(msgpack.dumps(header))
response = sock.recv()
if response != b"0":
logger.error(
"🔴Send Tensor Failed | %s 👉 %s | Rank:%s | shape:%s | size:%.4f GB | response:%s",
self.zmq_address, remote_address, rank,
tuple(send_tensor.shape), send_tensor.element_size() * send_tensor.numel() / 1024**3,
response.decode()
)
return False
# The destination rank is rank ^ 1 (presumably the other member of a
# two-rank communicator -- TODO confirm).
self._send(comm, send_tensor, rank ^ 1, self.send_stream)
if self.send_type == "PUT_ASYNC":
self._have_sent_tensor_id(tensor_id)
return True
def _send_sync_new(
    self,
    tensor_id: str,
    tensor: torch.Tensor,
    remote_address: typing.Optional[RemoteAddr] = None,
) -> bool:
    """
    Announce ``tensor`` to the peer with a "PUT_NEW" header and, once the
    peer answers ``b"0"``, transmit it with ``self._send``.

    Returns False when ``remote_address`` is None or when the peer answers
    anything other than ``b"0"``; returns True after the tensor was sent.
    """
    if remote_address is None:
        return False

    zmq_address = remote_address.zmq_address
    if zmq_address not in self.socks:
        self._create_connect_new(zmq_address)

    sock = self.socks[zmq_address]
    comm, rank = self.comms[remote_address.pd_pair_id]

    header = {
        "cmd": "PUT_NEW",
        "tensor_id": tensor_id,
        "shape": tensor.shape,
        "dtype": str(tensor.dtype).replace("torch.", ""),
        "pd_pair_id": remote_address.pd_pair_id,
        "comm_rank": rank
    }
    sock.send(msgpack.dumps(header))

    response = sock.recv()
    peer_accepted = (response == b"0")
    if not peer_accepted:
        logger.error(
            "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
            "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
            self.zmq_address, remote_address.zmq_address, rank, header,
            tensor.shape,
            tensor.element_size() * tensor.numel() / 1024**3,
            response.decode())
        return False

    self._send(comm, tensor.to(self.device), remote_address.comm_rank,
               self.send_stream)

    # In PUT_ASYNC mode the tensor id is recorded as sent.
    if self.send_type == "PUT_ASYNC":
        self._have_sent_tensor_id(tensor_id)
    return True
def _send_sync(
self,
tensor_id: str,
......@@ -531,3 +659,72 @@ class P2pNcclEngine:
self._send_thread.join()
if self._ping_thread is not None:
self._ping_thread.join()
def get_pp_indices_d(self, num_hidden_layers: int, pp_rank: int,
                     pp_size: int) -> tuple[int, int]:
    """
    Return the half-open ``(start_layer, end_layer)`` range of hidden
    layers assigned to ``pp_rank`` out of ``pp_size`` pipeline stages.

    When ``envs.VLLM_PP_LAYER_PARTITION_D`` is set it must be a
    comma-separated list of ``pp_size`` integers summing to
    ``num_hidden_layers``; otherwise a ``ValueError`` is raised. When it is
    unset, layers are split evenly and any remainder is handed out one
    layer each, starting at the second-to-last stage and moving backwards.
    """
    override = envs.VLLM_PP_LAYER_PARTITION_D

    if override is None:
        base, extra = divmod(num_hidden_layers, pp_size)
        partitions = [base] * pp_size
        if extra:
            # Offsets 2 .. extra+1 from the end: the last stage never
            # receives an extra layer.
            for offset in range(2, extra + 2):
                partitions[-offset] += 1
            logger.info(
                "Hidden layers were unevenly partitioned: [%s]. "
                "This can be manually overridden using the "
                "VLLM_PP_LAYER_PARTITION_D environment variable",
                ",".join(str(p) for p in partitions))
    else:
        try:
            partitions = list(map(int, override.split(",")))
        except ValueError as err:
            raise ValueError(
                "Invalid partition string: {}".format(override)) from err
        if len(partitions) != pp_size:
            raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
        if sum(partitions) != num_hidden_layers:
            raise ValueError(
                f"{sum(partitions)=} does not match {num_hidden_layers=}.")

    start_layer = sum(partitions[:pp_rank])
    return (start_layer, start_layer + partitions[pp_rank])
def compute_remote_pp_rank(self, layer_name: str) -> int:
    """
    Return the remote pipeline-parallel rank whose layer range contains
    the layer named ``layer_name``, or -1 if no remote rank matches.

    A layer index equal to ``self.total_num_hidden_layers`` (one past the
    regular layers) is mapped to the last remote rank.
    """
    layer_idx = extract_layer_index(layer_name)
    total_layers = self.total_num_hidden_layers
    remote_pp_size = self.remote_pp_size

    for candidate_rank in range(remote_pp_size):
        first, last = self.get_pp_indices_d(total_layers, candidate_rank,
                                            remote_pp_size)
        if layer_idx == total_layers:
            return remote_pp_size - 1
        if first <= layer_idx < last:
            return candidate_rank
    return -1
@staticmethod
def get_tensor_id(request_id: str, layer_name: str) -> str:
return request_id + "#" + layer_name
@staticmethod
def parse_request_id(request_id: str, is_prefill=True) -> tuple[str, int]:
    """
    Extract the ``(ip, port)`` pair embedded in ``request_id``.

    With ``is_prefill`` true the ``___decode_addr_<ip>:<port>`` marker is
    searched for, otherwise the ``___prefill_addr_<ip>:<port>___`` marker.
    Raises ``ValueError`` when the marker is absent.
    """
    pattern = (r"___decode_addr_(.*):(\d+)"
               if is_prefill else r"___prefill_addr_(.*):(\d+)___")

    match = regex.search(pattern, request_id)
    if match is None:
        raise ValueError(
            f"Request id {request_id} does not contain hostname and port")

    # Group 1 is the host part, group 2 the numeric port.
    return match.group(1), int(match.group(2))
......@@ -477,6 +477,9 @@ class EngineArgs:
enable_multimodal_encoder_data_parallel: bool = \
ParallelConfig.enable_multimodal_encoder_data_parallel
enable_dp_attention: bool = \
ParallelConfig.enable_dp_attention
def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
# without having to manually construct a
......@@ -719,6 +722,10 @@ class EngineArgs:
"--enable-multimodal-encoder-data-parallel",
**parallel_kwargs["enable_multimodal_encoder_data_parallel"])
parallel_group.add_argument(
"--enable-dp-attention",
**parallel_kwargs["enable_dp_attention"])
# KV cache arguments
cache_kwargs = get_kwargs(CacheConfig)
cache_group = parser.add_argument_group(
......@@ -1204,6 +1211,7 @@ class EngineArgs:
worker_extension_cls=self.worker_extension_cls,
enable_multimodal_encoder_data_parallel=self.
enable_multimodal_encoder_data_parallel,
enable_dp_attention=self.enable_dp_attention,
)
speculative_config = self.create_speculative_config(
......
......@@ -129,6 +129,8 @@ if TYPE_CHECKING:
VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
VLLM_ALL2ALL_BACKEND: str = "naive"
VLLM_MOE_HT_THRESHOLD: int = 128
VLLM_ALLOW_MNNVL: bool = False
VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
VLLM_SLEEP_WHEN_IDLE: bool = False
......@@ -147,6 +149,7 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False
VLLM_USE_TRITON_OPT_MLA: bool = False
VLLM_USE_FLASH_ATTN_FP8: bool = False
VLLM_USE_QUERY_QUANT: bool = False
VLLM_USE_FLASH_MLA: bool = False
VLLM_USE_FLASH_MLA_FP8: bool = False
VLLM_USE_OPT_OP: bool = False
......@@ -196,14 +199,25 @@ if TYPE_CHECKING:
VLLM_PP_DEBUG: bool = False
VLLM_USE_V32_ENCODE: bool = False
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False
VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA: bool = False
VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_MARLIN_W16A16_MOE:bool = False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
VLLM_USE_FUSED_FILL_RMS_CAT:bool = False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True
VLLM_ENABLE_DEEPEP_INT8_DISPATCH: bool = True
VLLM_ZERO_OVERHEAD_ENHANCE: bool = False
VLLM_USE_FUSED_QA_KVA_GEMM: bool = False
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
VLLM_DISABLE_SHARED_EXPERTS_STREAM:bool = True
VLLM_W8A8_BACKEND: int = 3
VLLM_MOE_ROUTER_CAPTURE: bool = False
VLLM_MOE_ROUTER_CAPTURE_DIR: str = "/tmp"
VLLM_MOE_ROUTER_CAPTURE_RANK: int = -1
VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS: int = 0
VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT: int = -1
VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT: int = -1
VLLM_ENABLE_SHARED_EXPERTS_FUSION: bool = False
VLLM_USE_MOE_W16A16_TRITON: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -980,6 +994,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ALL2ALL_BACKEND":
lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
# VLLM_MOE_HT_THRESHOLD
"VLLM_MOE_HT_THRESHOLD":
lambda: int(os.getenv("VLLM_MOE_HT_THRESHOLD", "128")),
# use ALLOW_MNNVL
"VLLM_ALLOW_MNNVL":
lambda: (os.environ.get("VLLM_ALLOW_MNNVL", "False").lower() in
("true", "1")),
# Control the maximum number of tokens per expert supported by the
# NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
# the blockscale tensor of activations NVFP4 Quantization.
......@@ -1050,7 +1072,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set, vLLM will use FLASH ATTN fp8 attention optimizations.
"VLLM_USE_FLASH_ATTN_FP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_ATTN_FP8", "0"))),
lambda: bool(int(os.getenv("VLLM_USE_FLASH_ATTN_FP8", "1"))),
# flag to control if vllm should use q quant
"VLLM_USE_QUERY_QUANT":
lambda: (os.environ.get("VLLM_USE_QUERY_QUANT", "False").lower() in
("true", "1")),
# If set, vLLM will use FLASH MLA attention optimizations.
"VLLM_USE_FLASH_MLA":
......@@ -1058,7 +1085,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set, vLLM will use FLASH MLA fp8 attention optimizations.
"VLLM_USE_FLASH_MLA_FP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA_FP8", "0"))),
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA_FP8", "1"))),
# flag to control vllm to use optimized kernels
"VLLM_USE_OPT_OP":
......@@ -1085,7 +1112,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# flag to control vllm to use optimized kernels
"VLLM_CUSTOM_CACHE":
lambda: bool(int(os.environ.get("VLLM_CUSTOM_CACHE", "0"))),
lambda: bool(int(os.environ.get("VLLM_CUSTOM_CACHE", "1"))),
# flag to control vllm to use optimized kernels
"VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX":
......@@ -1211,7 +1238,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vLLM will split prefill and decode, not mix up
"VLLM_USE_PD_SPLIT":
lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "False").lower() in
lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "True").lower() in
("true", "1")),
# vLLM will sync to avoid pp vmfault
......@@ -1283,14 +1310,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv('VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT', 'False').lower() in
("true", "1")),
# vllm will use fused rmsnorm + contiguous + rope(for dpsk-v3) + concat_and_cache_mla + q quant, control bmm + cat +mla (fp8)
"VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA":
lambda: (os.getenv('VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA', 'False').lower() in
("true", "1")),
# vLLM will use fused RMS + RoPE kernel
"VLLM_USE_FUSED_RMS_ROPE":
lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "False").lower() in
("true", "1")),
# vLLM will use Marlin W16A16 kernel for MoE experts
"VLLM_USE_MARLIN_W16A16_MOE":
lambda: (os.environ.get("VLLM_USE_MARLIN_W16A16_MOE", "False").lower() in
("true", "1")),
# vLLM will use lightop for dpsk mtp fill + rms*2 + cat
"VLLM_USE_FUSED_FILL_RMS_CAT":
lambda: (os.environ.get("VLLM_USE_FUSED_FILL_RMS_CAT", "False").lower() in
......@@ -1306,6 +1334,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM":
lambda: (os.getenv('VLLM_ENABLE_DEEPEP_HT_DEEPGEMM', '1').lower() in
("true", "1")),
# vLLM will use deepep int8 dispatch
"VLLM_ENABLE_DEEPEP_INT8_DISPATCH":
lambda: (os.getenv('VLLM_ENABLE_DEEPEP_INT8_DISPATCH', '1').lower() in
("true", "1")),
# Only quantized DeepSeek models supported.
# Unquantized versions are not supported.
"VLLM_USE_FUSED_QA_KVA_GEMM":
......@@ -1318,6 +1352,52 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_V1_FAST_TOKEN_ID_COPY":
lambda: (os.environ.get("VLLM_V1_FAST_TOKEN_ID_COPY", "False").lower() in
("true", "1")),
# shared experts overlap with routed experts
# VLLM_DISABLE_SHARED_EXPERTS_STREAM = 1 disable shared experts overlap
# VLLM_DISABLE_SHARED_EXPERTS_STREAM = 0 enable shared experts overlap
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool(
int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "1"))
),
# shared experts fusion
# VLLM_ENABLE_SHARED_EXPERTS_FUSION = 1 enable shared experts fusion
# VLLM_ENABLE_SHARED_EXPERTS_FUSION = 0 disable shared experts fusion
"VLLM_ENABLE_SHARED_EXPERTS_FUSION": lambda: bool(
int(os.getenv("VLLM_ENABLE_SHARED_EXPERTS_FUSION", "0"))
),
# W8A8 GEMM backend selection for vLLM quantized models.
# lightop/triton: 1
# cutlass: 2 (will remove in the future)
# blaslt: 3 (default)
# rocblas: others
"VLLM_W8A8_BACKEND": lambda: int(os.getenv("VLLM_W8A8_BACKEND", "3")),
# Capture MoE router logits for debugging/analysis.
"VLLM_MOE_ROUTER_CAPTURE":
lambda: (os.getenv("VLLM_MOE_ROUTER_CAPTURE", "0").lower() in ("true", "1")),
# Output directory for MoE router capture dumps.
"VLLM_MOE_ROUTER_CAPTURE_DIR":
lambda: os.environ.get(
"VLLM_MOE_ROUTER_CAPTURE_DIR",
"/tmp",
),
# Capture only the specified rank; set to -1 to capture all ranks.
"VLLM_MOE_ROUTER_CAPTURE_RANK":
lambda: int(os.environ.get("VLLM_MOE_ROUTER_CAPTURE_RANK", "-1")),
# Max number of MoE layers to record per process (0 = unlimited).
"VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS":
lambda: int(os.environ.get("VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS", "0")),
# Only capture when num_tokens > N (negative disables).
"VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT":
lambda: int(os.environ.get("VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT", "-1")),
# Only capture when num_tokens < N (0 disables).
"VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT":
lambda: int(os.environ.get("VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT", "-1")),
# Force using Triton MoE path (disable Marlin W16A16 MoE).
"VLLM_USE_MOE_W16A16_TRITON":
lambda: (os.environ.get("VLLM_USE_MOE_W16A16_TRITON", "0").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......@@ -1382,6 +1462,7 @@ def compute_hash() -> str:
"VLLM_DP_SIZE",
"VLLM_USE_STANDALONE_COMPILE",
"VLLM_FUSED_MOE_CHUNK_SIZE",
"VLLM_W8A8_BACKEND",
]
for key in environment_variables_to_hash:
if key in environment_variables:
......
......@@ -136,8 +136,8 @@ def set_forward_context(
forward_start_time = time.perf_counter()
dp_metadata: Optional[DPMetadata] = None
dp_size = vllm_config.parallel_config.data_parallel_size
use_navie_ep = envs.VLLM_ALL2ALL_BACKEND == 'naive' and dp_size > 1 and vllm_config.parallel_config.enable_expert_parallel
if use_navie_ep and dp_size > 1 and (
use_navie_all2all = envs.VLLM_ALL2ALL_BACKEND == 'naive' and dp_size > 1
if use_navie_all2all and dp_size > 1 and (
attn_metadata is not None or num_tokens is not None):
dp_metadata = DPMetadata.make(vllm_config.parallel_config,
attn_metadata, num_tokens or 0,
......@@ -211,3 +211,14 @@ def set_profilling(profiling):
def get_profilling() -> bool:
global _profiling
return _profiling
# Process-wide flag read through get_warming_up().
_warming_up = False


def set_warming_up(warming_up):
    """
    Set the module-level warm-up flag to ``warming_up``.

    The flag is updated immediately when this function is called, so a
    plain ``set_warming_up(True)`` keeps working as a setter. The return
    value is a context manager, so ``with set_warming_up(True): ...`` is
    also valid; on leaving the ``with`` block the flag is restored to the
    value it had before this call.

    Previously the function was decorated with ``@contextmanager`` without
    containing a ``yield``, so entering it in a ``with`` statement raised
    ``TypeError``.
    """
    global _warming_up
    previous = _warming_up
    _warming_up = warming_up

    @contextmanager
    def _restore_previous():
        global _warming_up
        try:
            yield
        finally:
            _warming_up = previous

    return _restore_previous()


def get_warming_up() -> bool:
    """Return the current value of the module-level warm-up flag."""
    return _warming_up
\ No newline at end of file
from typing import TYPE_CHECKING, List, Optional, Tuple
import logging
import torch
import vllm.envs as envs
from vllm.distributed.parallel_state import GroupCoordinator, init_model_parallel_group, get_world_group
from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
get_tensor_model_parallel_rank,
tensor_model_parallel_reduce_scatter,
get_tp_group)
_ENABLE_DP_ATTENTION_FLAG: bool = False
_MOE_TP: Optional[GroupCoordinator] = None
_ATTN_DP_SIZE = 0
_ATTN_TP_SIZE = 0
_ATTN_TP_RANK = 0
_ATTN_DP_RANK = 0
_MOT_TP_SIZE = 0
_MOT_TP_RANK = 0
def initialize_dp_attention(vllm_config, backend: Optional[str] = None):
    """
    Record the DP-attention topology in module globals and create the MoE
    tensor-parallel process group.

    Args:
        vllm_config: a ``VllmConfig``; its ``parallel_config`` supplies the
            data/tensor/pipeline parallel sizes, the data-parallel rank and
            the ``enable_dp_attention`` flag.
        backend: torch.distributed backend for the new group. Defaults to
            the backend of the world group's device group.

    Raises:
        AssertionError: if ``vllm_config`` is not a ``VllmConfig`` or the
            MoE tensor-parallel group already exists. The latter is checked
            before any module state is modified, so a repeated call leaves
            the previously recorded values untouched.
    """
    from vllm.config import VllmConfig
    assert isinstance(vllm_config, VllmConfig)

    global _ENABLE_DP_ATTENTION_FLAG, _ATTN_DP_SIZE, _ATTN_TP_SIZE
    global _ATTN_TP_RANK, _ATTN_DP_RANK, _MOT_TP_SIZE, _MOT_TP_RANK
    global _MOE_TP

    assert _MOE_TP is None, ("moe tensor model parallel group is already initialized")

    parallel_config = vllm_config.parallel_config
    _ENABLE_DP_ATTENTION_FLAG = parallel_config.enable_dp_attention

    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    pipeline_model_parallel_size = parallel_config.pipeline_parallel_size

    # Every pipeline stage forms one MoE tensor-parallel group made of all
    # ranks of that stage.
    moe_tp_size = world_size // pipeline_model_parallel_size

    _ATTN_DP_SIZE = parallel_config.data_parallel_size
    _ATTN_TP_SIZE = parallel_config.tensor_parallel_size
    _ATTN_TP_RANK = get_tensor_model_parallel_rank()
    _ATTN_DP_RANK = parallel_config.data_parallel_rank
    _MOT_TP_SIZE = moe_tp_size
    _MOT_TP_RANK = rank % _MOT_TP_SIZE

    # Build the moe tensor model-parallel groups: consecutive blocks of
    # moe_tp_size ranks, one block per pipeline stage.
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)
    group_ranks = [
        list(range(stage * moe_tp_size, (stage + 1) * moe_tp_size))
        for stage in range(pipeline_model_parallel_size)
    ]

    # message queue broadcaster is only used in tensor model parallel group
    _MOE_TP = init_model_parallel_group(group_ranks,
                                        get_world_group().local_rank,
                                        backend,
                                        use_message_queue_broadcaster=True,
                                        group_name="moe_tp")
def get_attention_tp_size() -> int:
    """Return the attention tensor-parallel size recorded at initialization."""
    # NOTE(review): the module default is 0 rather than None, so this
    # assertion cannot trip before initialization -- confirm intent.
    size = _ATTN_TP_SIZE
    assert size is not None, "dp attention not initialized!"
    return size
def get_attention_tp_rank() -> int:
    """Return the attention tensor-parallel rank recorded at initialization."""
    # NOTE(review): the module default is 0 rather than None, so this
    # assertion cannot trip before initialization -- confirm intent.
    tp_rank = _ATTN_TP_RANK
    assert tp_rank is not None, "dp attention not initialized!"
    return tp_rank
def get_moe_tp_group() -> GroupCoordinator:
    """Return the MoE tensor-parallel group; asserts that it has been created."""
    group = _MOE_TP
    assert group is not None, ("tensor model parallel group is not initialized")
    return group
def get_attention_dp_size() -> int:
    """Return the attention data-parallel size recorded at initialization."""
    # NOTE(review): the module default is 0 rather than None, so this
    # assertion cannot trip before initialization -- confirm intent.
    dp_size = _ATTN_DP_SIZE
    assert dp_size is not None, "dp attention not initialized!"
    return dp_size
def get_moe_tp_rank() -> int:
    """Return this process's rank inside its MoE tensor-parallel group."""
    # NOTE(review): the module default is 0 rather than None, so this
    # assertion cannot trip before initialization -- confirm intent.
    moe_rank = _MOT_TP_RANK
    assert moe_rank is not None, "dp attention not initialized!"
    return moe_rank
def get_moe_tp_size() -> int:
    """Return the size of the MoE tensor-parallel group recorded at initialization."""
    # NOTE(review): the module default is 0 rather than None, so this
    # assertion cannot trip before initialization -- confirm intent.
    moe_size = _MOT_TP_SIZE
    assert moe_size is not None, "dp attention not initialized!"
    return moe_size
def get_attention_tp_group() -> GroupCoordinator:
    """Return the group used for attention tensor parallelism (the regular TP group)."""
    attention_group = get_tp_group()
    return attention_group
def moe_tensor_model_parallel_all_gather(input_: torch.Tensor,
                                         dim: int = -1) -> torch.Tensor:
    """All-gather ``input_`` along ``dim`` over the MoE tensor-parallel group."""
    moe_group = get_moe_tp_group()
    return moe_group.all_gather(input_, dim)
def moe_tensor_model_parallel_reduce_scatter(input_: torch.Tensor,
                                             dim: int = -1) -> torch.Tensor:
    """Reduce-scatter ``input_`` along ``dim`` over the MoE tensor-parallel group."""
    moe_group = get_moe_tp_group()
    return moe_group.reduce_scatter(input_, dim)
def dp_gather(
        hidden_states: torch.Tensor,) -> torch.Tensor:
    """
    Gather ``hidden_states`` along dim 0 over the MoE tensor-parallel group.

    When the attention tensor-parallel size is not 1, the tensor is first
    reduce-scattered along dim 0 over the regular tensor-parallel group.
    """
    if get_attention_tp_size() != 1:
        hidden_states = tensor_model_parallel_reduce_scatter(hidden_states,
                                                             dim=0)
    return moe_tensor_model_parallel_all_gather(hidden_states, dim=0)
def dp_reduce_scatter_tensor(hidden_states: torch.Tensor) -> torch.Tensor:
    """
    Reduce-scatter ``hidden_states`` along dim 0 over the MoE
    tensor-parallel group.

    When that group's world size differs from the attention data-parallel
    size, the result is additionally all-gathered along dim 0 over the
    regular tensor-parallel group.
    """
    # Decide first, exactly as the original branch condition did.
    needs_tp_all_gather = (
        get_moe_tp_group().world_size != get_attention_dp_size())

    scattered = moe_tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
    if needs_tp_all_gather:
        scattered = tensor_model_parallel_all_gather(scattered, dim=0)
    return scattered
......@@ -187,6 +187,11 @@ class FusedMoEParallelConfig:
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
@property
def use_deepep_auto_kernels(self):
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
@staticmethod
def make(tp_size_: int, dp_size_: int,
vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
......@@ -385,6 +390,10 @@ class FusedMoEConfig:
def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels
@property
def use_deepep_auto_kernels(self):
return self.moe_parallel_config.use_deepep_auto_kernels
@staticmethod
def make(
num_experts: int,
......
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment