Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dcb5624a
Commit
dcb5624a
authored
Apr 29, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.5' into v0.8.5-dev
parents
55880ca2
ba41cc90
Changes
554
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2026 additions
and
17 deletions
+2026
-17
csrc/ops.h
csrc/ops.h
+6
-0
csrc/quantization/cutlass_w8a8/moe/moe_data.cu
csrc/quantization/cutlass_w8a8/moe/moe_data.cu
+15
-2
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh
...tization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh
+1
-1
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
...ization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
+1
-1
csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
+1
-1
csrc/quantization/gptq_marlin/marlin.cuh
csrc/quantization/gptq_marlin/marlin.cuh
+7
-2
csrc/quantization/gptq_marlin/marlin_dtypes.cuh
csrc/quantization/gptq_marlin/marlin_dtypes.cuh
+7
-3
csrc/rocm/ops.h
csrc/rocm/ops.h
+9
-0
csrc/rocm/skinny_gemms.cu
csrc/rocm/skinny_gemms.cu
+1600
-0
csrc/rocm/torch_bindings.cpp
csrc/rocm/torch_bindings.cpp
+18
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+7
-0
docker/Dockerfile
docker/Dockerfile
+9
-0
docker/Dockerfile.cpu
docker/Dockerfile.cpu
+1
-0
docker/Dockerfile.nightly_torch
docker/Dockerfile.nightly_torch
+307
-0
docker/Dockerfile.ppc64le
docker/Dockerfile.ppc64le
+14
-4
docker/Dockerfile.rocm_base
docker/Dockerfile.rocm_base
+1
-1
docker/Dockerfile.s390x
docker/Dockerfile.s390x
+22
-2
docs/source/assets/deployment/anything-llm-chat-with-doc.png
docs/source/assets/deployment/anything-llm-chat-with-doc.png
+0
-0
docs/source/assets/deployment/anything-llm-chat-without-doc.png
...ource/assets/deployment/anything-llm-chat-without-doc.png
+0
-0
docs/source/assets/deployment/anything-llm-provider.png
docs/source/assets/deployment/anything-llm-provider.png
+0
-0
No files found.
Too many changes to show.
To preserve performance only
554 of 554+
files are displayed.
Plain diff
Email patch
csrc/ops.h
View file @
dcb5624a
...
@@ -269,6 +269,12 @@ void advance_step_flashinfer(
...
@@ -269,6 +269,12 @@ void advance_step_flashinfer(
torch
::
Tensor
&
paged_kv_indices
,
torch
::
Tensor
&
paged_kv_indptr
,
torch
::
Tensor
&
paged_kv_indices
,
torch
::
Tensor
&
paged_kv_indptr
,
torch
::
Tensor
&
paged_kv_last_page_len
,
torch
::
Tensor
&
block_table_bounds
);
torch
::
Tensor
&
paged_kv_last_page_len
,
torch
::
Tensor
&
block_table_bounds
);
// void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
// torch::Tensor const& q_pe,
// torch::Tensor const& kv_c_and_k_pe_cache,
// torch::Tensor const& seq_lens,
// torch::Tensor const& page_table, double scale);
torch
::
Tensor
get_cuda_view_from_cpu_tensor
(
torch
::
Tensor
&
cpu_tensor
);
torch
::
Tensor
get_cuda_view_from_cpu_tensor
(
torch
::
Tensor
&
cpu_tensor
);
#ifndef USE_ROCM
#ifndef USE_ROCM
...
...
csrc/quantization/cutlass_w8a8/moe/moe_data.cu
View file @
dcb5624a
...
@@ -46,14 +46,26 @@ __global__ void compute_expert_offsets(
...
@@ -46,14 +46,26 @@ __global__ void compute_expert_offsets(
}
}
__global__
void
compute_arg_sorts
(
const
int
*
__restrict__
topk_ids
,
__global__
void
compute_arg_sorts
(
const
int
*
__restrict__
topk_ids
,
const
int32_t
*
__restrict__
expert_offsets
,
int32_t
*
input_permutation
,
int32_t
*
input_permutation
,
int32_t
*
output_permutation
,
int32_t
*
output_permutation
,
int32_t
*
atomic_buffer
,
const
int
topk_length
,
int32_t
*
atomic_buffer
,
const
int
topk_length
,
const
int
topk
)
{
const
int
topk
)
{
int
expert_id
=
blockIdx
.
x
;
int
const
blk_expert_id
=
blockIdx
.
x
;
int
const
num_experts
=
gridDim
.
x
;
int32_t
const
num_tokens
=
expert_offsets
[
num_experts
];
for
(
int
i
=
threadIdx
.
x
;
i
<
topk_length
;
i
+=
THREADS_PER_EXPERT
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
topk_length
;
i
+=
THREADS_PER_EXPERT
)
{
if
(
topk_ids
[
i
]
==
expert_id
)
{
int
const
expert_id
=
topk_ids
[
i
];
if
(
expert_id
==
-
1
&&
blockIdx
.
x
==
0
)
{
// output_permutation is used to re-order the moe outputs. It is
// used as c2 = c2[c_map], where c2 is a torch.tensor that is the
// output of the cutlass kernels and c_map is the output_permutation.
// c2 is initialized to zeros, therefore by setting the output_permutation
// to num_tokens, we are guaranteed to fill the moe outputs to zero
// for "invalid" topk_ids.
output_permutation
[
i
]
=
num_tokens
;
}
else
if
(
expert_id
==
blk_expert_id
)
{
int
start
=
atomicAdd
(
&
atomic_buffer
[
expert_id
],
1
);
int
start
=
atomicAdd
(
&
atomic_buffer
[
expert_id
],
1
);
input_permutation
[
start
]
=
i
/
topk
;
input_permutation
[
start
]
=
i
/
topk
;
output_permutation
[
i
]
=
start
;
output_permutation
[
i
]
=
start
;
...
@@ -83,6 +95,7 @@ void get_cutlass_moe_mm_data_caller(
...
@@ -83,6 +95,7 @@ void get_cutlass_moe_mm_data_caller(
static_cast
<
int32_t
*>
(
atomic_buffer
.
data_ptr
()),
num_experts
);
static_cast
<
int32_t
*>
(
atomic_buffer
.
data_ptr
()),
num_experts
);
compute_arg_sorts
<<<
num_experts
,
num_threads
,
0
,
stream
>>>
(
compute_arg_sorts
<<<
num_experts
,
num_threads
,
0
,
stream
>>>
(
static_cast
<
const
int32_t
*>
(
topk_ids
.
data_ptr
()),
static_cast
<
const
int32_t
*>
(
topk_ids
.
data_ptr
()),
static_cast
<
const
int32_t
*>
(
expert_offsets
.
data_ptr
()),
static_cast
<
int32_t
*>
(
input_permutation
.
data_ptr
()),
static_cast
<
int32_t
*>
(
input_permutation
.
data_ptr
()),
static_cast
<
int32_t
*>
(
output_permutation
.
data_ptr
()),
static_cast
<
int32_t
*>
(
output_permutation
.
data_ptr
()),
static_cast
<
int32_t
*>
(
atomic_buffer
.
data_ptr
()),
topk_ids
.
numel
(),
static_cast
<
int32_t
*>
(
atomic_buffer
.
data_ptr
()),
topk_ids
.
numel
(),
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh
View file @
dcb5624a
...
@@ -336,7 +336,7 @@ inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
...
@@ -336,7 +336,7 @@ inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
32
),
next_pow_2
(
m
));
// next power of 2
std
::
max
(
static_cast
<
uint32_t
>
(
16
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
16
)
{
if
(
mp2
<=
16
)
{
// M in [1, 16]
// M in [1, 16]
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
View file @
dcb5624a
...
@@ -321,7 +321,7 @@ inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
...
@@ -321,7 +321,7 @@ inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
32
),
next_pow_2
(
m
));
// next power of 2
std
::
max
(
static_cast
<
uint32_t
>
(
16
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
16
)
{
if
(
mp2
<=
16
)
{
// M in [1, 16]
// M in [1, 16]
...
...
csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
View file @
dcb5624a
...
@@ -134,7 +134,7 @@ typename T::Gemm::Arguments args_from_options(
...
@@ -134,7 +134,7 @@ typename T::Gemm::Arguments args_from_options(
using
StrideB
=
typename
T
::
StrideB
;
using
StrideB
=
typename
T
::
StrideB
;
using
StrideD
=
typename
T
::
StrideD
;
using
StrideD
=
typename
T
::
StrideD
;
using
Sm100BlkScaledConfig
=
using
Sm100BlkScaledConfig
=
typename
T
::
Gemm
::
GemmKernel
::
CollectiveMainloop
::
Sm1
00
BlkScaledConfig
;
typename
T
::
Gemm
::
GemmKernel
::
CollectiveMainloop
::
Sm1
xx
BlkScaledConfig
;
int
m
=
static_cast
<
int
>
(
M
);
int
m
=
static_cast
<
int
>
(
M
);
int
n
=
static_cast
<
int
>
(
N
);
int
n
=
static_cast
<
int
>
(
N
);
...
...
csrc/quantization/gptq_marlin/marlin.cuh
View file @
dcb5624a
...
@@ -9,7 +9,11 @@
...
@@ -9,7 +9,11 @@
#include <cuda_runtime.h>
#include <cuda_runtime.h>
#include <iostream>
#include <iostream>
namespace
marlin
{
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif
namespace
MARLIN_NAMESPACE_NAME
{
// Marlin params
// Marlin params
...
@@ -23,6 +27,7 @@ static constexpr int pipe_stages =
...
@@ -23,6 +27,7 @@ static constexpr int pipe_stages =
static
constexpr
int
min_thread_n
=
64
;
static
constexpr
int
min_thread_n
=
64
;
static
constexpr
int
min_thread_k
=
64
;
static
constexpr
int
min_thread_k
=
64
;
static
constexpr
int
max_thread_n
=
256
;
static
constexpr
int
tile_size
=
16
;
static
constexpr
int
tile_size
=
16
;
static
constexpr
int
max_par
=
16
;
static
constexpr
int
max_par
=
16
;
...
@@ -84,4 +89,4 @@ __device__ inline void cp_async_wait() {
...
@@ -84,4 +89,4 @@ __device__ inline void cp_async_wait() {
#endif
#endif
}
// namespace
marlin
}
// namespace
MARLIN_NAMESPACE_NAME
csrc/quantization/gptq_marlin/marlin_dtypes.cuh
View file @
dcb5624a
...
@@ -5,7 +5,11 @@
...
@@ -5,7 +5,11 @@
#include <cuda_fp16.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_bf16.h>
namespace
marlin
{
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif
namespace
MARLIN_NAMESPACE_NAME
{
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
class
ScalarType
{};
class
ScalarType
{};
...
@@ -54,7 +58,7 @@ class ScalarType<nv_bfloat16> {
...
@@ -54,7 +58,7 @@ class ScalarType<nv_bfloat16> {
using
FragS
=
Vec
<
nv_bfloat162
,
1
>
;
using
FragS
=
Vec
<
nv_bfloat162
,
1
>
;
using
FragZP
=
Vec
<
nv_bfloat162
,
4
>
;
using
FragZP
=
Vec
<
nv_bfloat162
,
4
>
;
#if defined(__CUDA_ARCH__)
&&
__CUDA_ARCH__ >= 800
#if
!
defined(__CUDA_ARCH__)
||
__CUDA_ARCH__ >= 800
static
__device__
float
inline
num2float
(
const
nv_bfloat16
x
)
{
static
__device__
float
inline
num2float
(
const
nv_bfloat16
x
)
{
return
__bfloat162float
(
x
);
return
__bfloat162float
(
x
);
}
}
...
@@ -74,6 +78,6 @@ class ScalarType<nv_bfloat16> {
...
@@ -74,6 +78,6 @@ class ScalarType<nv_bfloat16> {
#endif
#endif
};
};
}
// namespace
marlin
}
// namespace
MARLIN_NAMESPACE_NAME
#endif
#endif
csrc/rocm/ops.h
View file @
dcb5624a
...
@@ -2,6 +2,15 @@
...
@@ -2,6 +2,15 @@
#include <torch/all.h>
#include <torch/all.h>
torch
::
Tensor
LLMM1
(
at
::
Tensor
&
in_a
,
at
::
Tensor
&
in_b
,
const
int64_t
rows_per_block
);
torch
::
Tensor
wvSplitK
(
at
::
Tensor
&
in_a
,
at
::
Tensor
&
in_b
,
const
int64_t
CuCount
);
void
wvSplitKQ
(
at
::
Tensor
&
in_a
,
at
::
Tensor
&
in_b
,
at
::
Tensor
&
out_c
,
at
::
Tensor
&
scale_a
,
at
::
Tensor
&
scale_b
,
const
int64_t
CuCount
);
void
paged_attention
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
void
paged_attention
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
...
...
csrc/rocm/skinny_gemms.cu
0 → 100644
View file @
dcb5624a
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <stdexcept>
#include <algorithm>
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/fp8/common.cuh"
#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__))
#define __HIP__MI300_MI250__
#endif
#if defined(__HIPCC__) && defined(__gfx942__)
#define __HIP__MI300__
#endif
#if defined(NDEBUG)
#undef NDEBUG
#include <assert.h>
#define UNREACHABLE_CODE assert(false);
#define NDEBUG
#else
#define UNREACHABLE_CODE assert(false);
#endif
template
<
typename
T
>
struct
scalar
{};
template
<
typename
T
>
struct
scalar2
{};
template
<
typename
T
>
__device__
__forceinline__
float2
__s22float2
(
T
v
);
template
<
typename
T
>
__device__
__forceinline__
T
__float2s
(
float
v
);
template
<
typename
T
>
__device__
__forceinline__
T
__float22s2_rn
(
float2
v
);
// Definitions and cvt functions for fp16
template
<
>
struct
scalar
<
c10
::
Half
>
{
using
type
=
half
;
};
template
<
>
struct
scalar2
<
c10
::
Half
>
{
using
type
=
__half2
;
};
template
<
>
__device__
__forceinline__
half
__float2s
(
float
v
)
{
return
__float2half
(
v
);
}
template
<
>
__device__
__forceinline__
float2
__s22float2
(
__half2
v
)
{
return
__half22float2
(
v
);
}
template
<
>
__device__
__forceinline__
__half2
__float22s2_rn
(
float2
v
)
{
return
__float22half2_rn
(
v
);
}
// Definitions and cvt functions for bf16
template
<
>
struct
scalar
<
c10
::
BFloat16
>
{
using
type
=
__hip_bfloat16
;
};
template
<
>
struct
scalar2
<
c10
::
BFloat16
>
{
using
type
=
__hip_bfloat162
;
};
template
<
>
__device__
__forceinline__
__hip_bfloat16
__float2s
(
float
v
)
{
return
__float2bfloat16
(
v
);
}
template
<
>
__device__
__forceinline__
float2
__s22float2
(
__hip_bfloat162
v
)
{
return
__bfloat1622float2
(
v
);
}
template
<
>
__device__
__forceinline__
__hip_bfloat162
__float22s2_rn
(
float2
v
)
{
return
__float22bfloat162_rn
(
v
);
}
template
<
typename
T
>
__device__
__forceinline__
T
loadnt
(
T
*
addr
)
{
return
__builtin_nontemporal_load
(
addr
);
}
__device__
__forceinline__
float4
load_ntmprl
(
const
float4
*
addr
)
{
auto
addr_alias
=
reinterpret_cast
<
const
float
*>
(
addr
);
auto
dat0
=
loadnt
(
addr_alias
);
auto
dat1
=
loadnt
(
addr_alias
+
1
);
auto
dat2
=
loadnt
(
addr_alias
+
2
);
auto
dat3
=
loadnt
(
addr_alias
+
3
);
return
make_float4
(
dat0
,
dat1
,
dat2
,
dat3
);
}
// TBlock fetches entire rows of A, and entire col of B (K dimension); assume
// N=1 for time being grid is M/A_NUM_ROWS blocks
template
<
typename
scalar_t
,
int
NUM_A_ROWS_PER_BLOCK
>
__global__
void
LLGemm1_kernel
(
const
scalar_t
*
in_a
,
const
scalar_t
*
in_b
,
scalar_t
*
out_c
,
const
int
K
)
{
using
scalar2_t
=
typename
scalar2
<
scalar_t
>::
type
;
auto
af4
=
reinterpret_cast
<
const
float4
*>
(
in_a
);
auto
bf4
=
reinterpret_cast
<
const
scalar2_t
*>
(
in_b
);
auto
c
=
reinterpret_cast
<
scalar2_t
*>
(
out_c
);
__shared__
float
red_smem
[
NUM_A_ROWS_PER_BLOCK
][
WARP_SIZE
];
const
int
row_addr
=
blockIdx
.
x
*
NUM_A_ROWS_PER_BLOCK
*
K
/
8
;
const
int
threadid
=
threadIdx
.
x
;
const
int
warp
=
threadIdx
.
x
/
WARP_SIZE
;
const
int
lane
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
num_warps
=
blockDim
.
x
/
WARP_SIZE
;
const
int
qwarpid
=
threadid
/
num_warps
;
const
int
qthreadid
=
threadid
%
num_warps
;
float4
rowA_elem4
[
NUM_A_ROWS_PER_BLOCK
];
scalar2_t
colB_elem4x
,
colB_elem4y
,
colB_elem4z
,
colB_elem4w
;
float
acc
[
NUM_A_ROWS_PER_BLOCK
];
scalar2_t
acch2
;
scalar2_t
oval
;
// As we later use warp shuffle operations, we may have more threads in the
// block than the actual available data, hence the if guard here.
if
(
threadid
*
8
<
K
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_A_ROWS_PER_BLOCK
;
i
++
)
{
// rowA_elem4[i] holds 8 * half numbers seen as a single float4.
rowA_elem4
[
i
]
=
load_ntmprl
(
&
af4
[
row_addr
+
threadid
+
K
/
8
*
i
]);
}
}
colB_elem4x
=
bf4
[
threadid
*
4
+
0
];
colB_elem4y
=
bf4
[
threadid
*
4
+
1
];
colB_elem4z
=
bf4
[
threadid
*
4
+
2
];
colB_elem4w
=
bf4
[
threadid
*
4
+
3
];
scalar2_t
Af2
;
[[
maybe_unused
]]
scalar2_t
Bf2
;
float2
S
;
auto
Ah2ptr
=
reinterpret_cast
<
scalar2_t
*>
(
&
rowA_elem4
);
scalar2_t
*
ah2lptr
;
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_A_ROWS_PER_BLOCK
;
i
++
)
{
// Multiply-add on 8 scalar_t.
ah2lptr
=
Ah2ptr
+
i
*
4
;
Af2
=
*
(
ah2lptr
);
acch2
=
__hmul2
(
Af2
,
colB_elem4x
);
Af2
=
*
(
ah2lptr
+
1
);
acch2
=
__hfma2
(
Af2
,
colB_elem4y
,
acch2
);
Af2
=
*
(
ah2lptr
+
2
);
acch2
=
__hfma2
(
Af2
,
colB_elem4z
,
acch2
);
Af2
=
*
(
ah2lptr
+
3
);
acch2
=
__hfma2
(
Af2
,
colB_elem4w
,
acch2
);
S
=
__s22float2
(
acch2
);
// See comment above concerning the if guard.
acc
[
i
]
=
(
threadid
*
8
<
K
?
S
.
x
+
S
.
y
:
0.
f
);
}
// all reduce across warp.
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_A_ROWS_PER_BLOCK
;
i
++
)
{
acc
[
i
]
+=
__shfl_xor
(
acc
[
i
],
mask
);
}
}
// Warp leaders store the data to shared memory.
if
(
lane
<
NUM_A_ROWS_PER_BLOCK
)
{
red_smem
[
lane
][
warp
]
=
acc
[
lane
];
}
// Make sure the data is in shared memory.
__syncthreads
();
if
(
qwarpid
<
NUM_A_ROWS_PER_BLOCK
)
{
acc
[
qwarpid
]
=
qthreadid
<
num_warps
?
red_smem
[
qwarpid
][
qthreadid
]
:
0.
f
;
for
(
int
mask
=
num_warps
/
2
;
mask
>=
1
;
mask
/=
2
)
{
acc
[
qwarpid
]
+=
__shfl_xor
(
acc
[
qwarpid
],
mask
);
}
float
oval2
=
__shfl_xor
(
acc
[
qwarpid
],
num_warps
);
if
(
lane
%
(
num_warps
*
2
)
==
0
)
{
oval
=
__float22s2_rn
<
scalar2_t
>
(
make_float2
(
acc
[
qwarpid
],
oval2
));
c
[
blockIdx
.
x
*
NUM_A_ROWS_PER_BLOCK
/
2
+
qwarpid
/
2
]
=
oval
;
}
}
}
torch
::
Tensor
LLMM1
(
at
::
Tensor
&
in_a
,
at
::
Tensor
&
in_b
,
const
int64_t
rows_per_block
)
{
auto
M
=
in_a
.
size
(
0
);
auto
K
=
in_a
.
size
(
1
);
auto
N
=
in_b
.
size
(
0
);
TORCH_CHECK
(
N
==
1
,
"Row number of activation tensor must be 1."
);
TORCH_CHECK
(
in_a
.
dtype
()
==
in_b
.
dtype
());
TORCH_CHECK
(
in_b
.
dtype
()
==
torch
::
kFloat16
||
in_b
.
dtype
()
==
torch
::
kBFloat16
);
auto
out_c
=
torch
::
empty
(
{
N
,
M
},
torch
::
TensorOptions
().
dtype
(
in_b
.
dtype
()).
device
(
in_b
.
device
()));
// NUM_TREADS need to be a multiple of WARP_SIZE, as we are using warp shuffle
// operations.
const
int
NUM_THREADS
=
K
*
2
/
16
%
WARP_SIZE
==
0
?
K
*
2
/
16
:
K
*
2
/
16
+
(
WARP_SIZE
-
K
*
2
/
16
%
WARP_SIZE
);
int
NUM_BLOCKS
=
M
/
rows_per_block
;
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
in_b
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
// call the kernel function...
AT_DISPATCH_REDUCED_FLOATING_TYPES
(
in_b
.
scalar_type
(),
"LLGemm1"
,
[
&
]
{
auto
a_ptr
=
in_a
.
data_ptr
<
scalar_t
>
();
auto
b_ptr
=
in_b
.
data_ptr
<
scalar_t
>
();
auto
c_ptr
=
out_c
.
data_ptr
<
scalar_t
>
();
if
(
rows_per_block
==
2
)
{
LLGemm1_kernel
<
scalar_t
,
2
>
<<<
NUM_BLOCKS
,
NUM_THREADS
,
0
,
stream
>>>
(
a_ptr
,
b_ptr
,
c_ptr
,
K
);
}
else
if
(
rows_per_block
==
4
)
{
LLGemm1_kernel
<
scalar_t
,
4
>
<<<
NUM_BLOCKS
,
NUM_THREADS
,
0
,
stream
>>>
(
a_ptr
,
b_ptr
,
c_ptr
,
K
);
}
else
if
(
rows_per_block
==
8
)
{
LLGemm1_kernel
<
scalar_t
,
8
>
<<<
NUM_BLOCKS
,
NUM_THREADS
,
0
,
stream
>>>
(
a_ptr
,
b_ptr
,
c_ptr
,
K
);
}
else
if
(
rows_per_block
==
16
)
{
LLGemm1_kernel
<
scalar_t
,
16
>
<<<
NUM_BLOCKS
,
NUM_THREADS
,
0
,
stream
>>>
(
a_ptr
,
b_ptr
,
c_ptr
,
K
);
}
else
{
NUM_BLOCKS
=
M
/
4
;
LLGemm1_kernel
<
scalar_t
,
4
>
<<<
NUM_BLOCKS
,
NUM_THREADS
,
0
,
stream
>>>
(
a_ptr
,
b_ptr
,
c_ptr
,
K
);
}
});
return
out_c
;
}
#define DOT2C(V0, V2, V3) \
if constexpr (std::is_same_v<scalar_t, half>) { \
asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(V0) : "0"(V0), "v"(V2), "v"(V3)); \
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) { \
float2 s = __bfloat1622float2(*((__hip_bfloat162*)(&(V2)))) * \
__bfloat1622float2(*((__hip_bfloat162*)(&(V3)))); \
V0 += (s.x + s.y); \
}
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
// This version targets cases where A[] fits LDS capacity
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
__launch_bounds__
(
WvPrGrp
*
THRDS
)
wvSplitK_hf_sml_
(
const
int
K
,
const
int
M
,
const
scalar_t
*
B
,
const
scalar_t
*
__restrict__
A
,
scalar_t
*
C
,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
using
scalar8
=
__attribute__
((
__vector_size__
((
A_CHUNK
/
2
)
*
sizeof
(
float
))))
float
;
union
bigType
{
scalar_t
h
[
A_CHUNK
];
float
f
[
A_CHUNK
/
2
];
float2
f2
[
A_CHUNK
/
4
];
double
d
[
A_CHUNK
/
4
];
scalar8
h8
;
};
//----------------------------------------------------
// Reserving 64 KB of LDS to have 1 WG / CU
// Goal is to bring the activation matrix A to the LDS
// and use it across the lifetime of the work group
// TODO: When activation matrix is larger than 64 KB
// then this is not goint to work!
//----------------------------------------------------
__shared__
scalar_t
s
[
1024
*
32
];
//----------------------------------------------------
// Fetch the activation matrix to LDS
// Loop iteration:
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements
// - Each WG will fetch 512 * 16 => 8K elements
// - Then the WG will move to another 8 K elements
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for
(
uint32_t
k
=
0
;
k
<
min
(
K
*
N
,
32
*
1024
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
uint32_t
k_in
=
k
+
((
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
);
if
(
k_in
>=
min
(
K
*
N
,
32
*
1024
))
break
;
*
((
bigType
*
)(
&
s
[
k_in
]))
=
*
((
bigType
*
)(
&
A
[
k_in
]));
}
__syncthreads
();
if
(
threadIdx
.
y
>=
_WvPrGrp
)
return
;
uint32_t
m
=
(
blockIdx
.
x
*
_WvPrGrp
+
(
threadIdx
.
y
%
_WvPrGrp
))
*
YTILE
;
float
sum
[
N
][
YTILE
];
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
// There are 16 waves per WG, and hence, each WG is
// working on 16 columns of weight matrix. Moreover,
// we tile in column direction by YTILE, so when YTILE=1
// the above math is right, however, when YTILE=2 then
// each wave will be working on 2 columns and WG will
// be working on 32 columns.
//
// Top level loop that makes WGs persistent!
// - WGs iterates across columns of weight matrix
// - Each wave within WG works on a given column(s)
// - After completing first set of columns, WGs start
// working on the next set of available columns
//----------------------------------------------------
while
(
m
<
M
)
{
//----------------------------------------------------
// 'sum' accumulates the matrix A x B computation
// split across 64 lanes.
//
// YTILE represents how many column of weight matrix
// are being worked on by each wave.
//----------------------------------------------------
for
(
int
i
=
0
;
i
<
YTILE
;
i
++
)
for
(
int
n
=
0
;
n
<
N
;
n
++
)
sum
[
n
][
i
]
=
0
;
bigType
bigA
[
N
][
UNRL
];
bigType
bigB
[
YTILE
][
UNRL
];
//----------------------------------------------------
// Fetch weight matrix B in interleaved K-split!
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements (1024B)
// - YTILE represents the number of column being serviced
// by wave
// - Loop for fetching weight matrix (B) are unrolled
//
// Fetch activation matrix A from LDS
// - Loop for fetching activation matrix (A) are unrolled
//
// Finally, do the matrix multiplication in an unrolled
// fashion. This provides lot of food for compiler
// scheduling.
//
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
// for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
for
(
uint32_t
k1
=
0
;
k1
<
K
;
k1
+=
THRDS
*
A_CHUNK
*
UNRL
)
{
// Fetch the weight matrix from memory!
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
const
scalar_t
*
B_
=
&
B
[(
m
+
0
)
*
K
+
k_
];
bigB
[
0
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
0
*
K
])));
//----------------------------------------------------
// The following code with YTILE > 1 has to be deleted
//----------------------------------------------------
if
constexpr
(
YTILE
>=
2
)
bigB
[
1
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
1
*
K
])));
if
constexpr
(
YTILE
>=
3
)
bigB
[
2
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
2
*
K
])));
if
constexpr
(
YTILE
>=
4
)
bigB
[
3
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
3
*
K
])));
if
constexpr
(
YTILE
>=
5
)
bigB
[
4
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
4
*
K
])));
if
constexpr
(
YTILE
>=
6
)
bigB
[
5
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
5
*
K
])));
if
constexpr
(
YTILE
>=
7
)
bigB
[
6
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
6
*
K
])));
if
constexpr
(
YTILE
>=
8
)
bigB
[
7
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
7
*
K
])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
// Fetch A activation matrix in interleaved fashion from LDS or memory
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
s
[
k_
+
K
*
n
])));
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for
(
uint32_t
n
=
0
;
n
<
N
;
n
++
)
{
#pragma unroll
for
(
uint32_t
b
=
0
;
b
<
A_CHUNK
/
2
;
b
++
)
{
DOT2C
(
sum
[
n
][
0
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
0
][
k2
].
f
[
b
])
//----------------------------------------------------
// The following code with YTILE > 1
//----------------------------------------------------
if
constexpr
(
YTILE
>=
2
)
{
DOT2C
(
sum
[
n
][
1
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
1
][
k2
].
f
[
b
]);
}
if
constexpr
(
YTILE
>=
3
)
{
DOT2C
(
sum
[
n
][
2
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
2
][
k2
].
f
[
b
]);
}
if
constexpr
(
YTILE
>=
4
)
{
DOT2C
(
sum
[
n
][
3
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
3
][
k2
].
f
[
b
]);
}
if
constexpr
(
YTILE
>=
5
)
{
DOT2C
(
sum
[
n
][
4
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
4
][
k2
].
f
[
b
]);
}
if
constexpr
(
YTILE
>=
6
)
{
DOT2C
(
sum
[
n
][
5
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
5
][
k2
].
f
[
b
]);
}
if
constexpr
(
YTILE
>=
7
)
{
DOT2C
(
sum
[
n
][
6
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
6
][
k2
].
f
[
b
]);
}
if
constexpr
(
YTILE
>=
8
)
{
DOT2C
(
sum
[
n
][
7
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
7
][
k2
].
f
[
b
]);
}
}
}
}
}
//----------------------------------------------------
// Final reduction step using shuffle
//----------------------------------------------------
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
}
}
if
(
threadIdx
.
x
==
63
)
{
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
i
=
0
;
i
<
YTILE
;
i
++
)
{
// if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]);
C
[
m
+
i
+
n
*
M
]
=
__float2s
<
scalar_t
>
(
sum
[
n
][
i
]);
}
}
}
m
+=
CuCount
*
_WvPrGrp
*
YTILE
;
}
}
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
wvSplitK_hf_sml_
(
const
int
K
,
const
int
M
,
const
scalar_t
*
B
,
const
scalar_t
*
__restrict__
A
,
scalar_t
*
C
,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
// This version targets cases where A[] marginally exceeds LDS capacity
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
__launch_bounds__
(
WvPrGrp
*
THRDS
)
wvSplitK_hf_
(
const
int
K
,
const
int
M
,
const
scalar_t
*
B
,
const
scalar_t
*
__restrict__
A
,
scalar_t
*
C
,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
using
scalar8
=
__attribute__
((
__vector_size__
((
A_CHUNK
/
2
)
*
sizeof
(
float
))))
float
;
union
bigType
{
scalar_t
h
[
A_CHUNK
];
float
f
[
A_CHUNK
/
2
];
float2
f2
[
A_CHUNK
/
4
];
double
d
[
A_CHUNK
/
4
];
scalar8
h8
;
};
//----------------------------------------------------
// Reserving 64 KB of LDS to have 1 WG / CU
// Goal is to bring the activation matrix A to the LDS
// and use it across the lifetime of the work group
// TODO: When activation matrix is larger than 64 KB
// then this is not goint to work!
//----------------------------------------------------
__shared__
scalar_t
s
[
1024
*
32
];
//----------------------------------------------------
// Computation of columns that need to be committed to memory!
//----------------------------------------------------
uint32_t
commitColumn
[
YTILE
];
for
(
uint32_t
i
=
0
;
i
<
YTILE
;
i
++
)
{
commitColumn
[
i
]
=
1
;
}
//----------------------------------------------------
// Indexing function into the column of weight matrix B
// Algorithm does 64 lane k-splitting / wave and uses
// WG ID and Thread ID to find the index.
//----------------------------------------------------
// int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp);
uint32_t
m
=
(
blockIdx
.
x
*
_WvPrGrp
+
threadIdx
.
y
)
*
YTILE
;
// Check whether there will be fragmenation!
// This will happen only for the last wave!
if
(
m
<
M
&&
(
m
+
YTILE
)
>=
M
)
{
uint32_t
startColumn
=
M
-
YTILE
;
for
(
uint32_t
i
=
0
;
i
<
(
m
-
startColumn
);
i
++
)
{
commitColumn
[
i
]
=
0
;
}
m
=
startColumn
;
}
//----------------------------------------------------
// Fetch the activation matrix to LDS
// Loop iteration:
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements
// - Each WG will fetch 512 * 16 => 8K elements
// - Then the WG will move to another 8 K elements
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for
(
uint32_t
k
=
0
;
k
<
min
(
K
*
N
,
32
*
1024
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
uint32_t
k_in
=
k
+
((
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
);
if
(
k_in
>=
min
(
K
*
N
,
32
*
1024
))
break
;
*
((
bigType
*
)(
&
s
[
k_in
]))
=
*
((
bigType
*
)(
&
A
[
k_in
]));
}
__syncthreads
();
if
(
threadIdx
.
y
>=
_WvPrGrp
)
return
;
float
sum
[
N
][
YTILE
];
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
// There are 16 waves per WG, and hence, each WG is
// working on 16 columns of weight matrix. Moreover,
// we tile in column direction by YTILE, so when YTILE=1
// the above math is right, however, when YTILE=2 then
// each wave will be working on 2 columns and WG will
// be working on 32 columns.
//
// Top level loop that makes WGs persistent!
// - WGs iterates across columns of weight matrix
// - Each wave within WG works on a given column(s)
// - After completing first set of columns, WGs start
// working on the next set of available columns
//----------------------------------------------------
while
(
m
<
M
)
{
//----------------------------------------------------
// 'sum' accumulates the matrix A x B computation
// split across 64 lanes.
//
// YTILE represents how many column of weight matrix
// are being worked on by each wave.
//----------------------------------------------------
for
(
int
i
=
0
;
i
<
YTILE
;
i
++
)
for
(
int
n
=
0
;
n
<
N
;
n
++
)
sum
[
n
][
i
]
=
0
;
bigType
bigA
[
N
][
UNRL
];
bigType
bigB
[
YTILE
][
UNRL
];
//----------------------------------------------------
// Fetch weight matrix B in interleaved K-split!
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements (1024B)
// - YTILE represents the number of column being serviced
// by wave
// - Loop for fetching weight matrix (B) are unrolled
//
// Fetch activation matrix A from LDS
// - Loop for fetching activation matrix (A) are unrolled
//
// Finally, do the matrix multiplication in an unrolled
// fashion. This provides lot of food for compiler
// scheduling.
//
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for
(
uint32_t
k1
=
0
;
k1
<
K
;
k1
+=
THRDS
*
A_CHUNK
*
UNRL
)
{
// Fetch the weight matrix from memory!
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
const
scalar_t
*
B_
=
&
B
[(
m
+
0
)
*
K
+
k_
];
bigB
[
0
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
0
*
K
])));
//----------------------------------------------------
// The following code with YTILE > 1 has to be deleted
//----------------------------------------------------
if
constexpr
(
YTILE
>=
2
)
bigB
[
1
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
1
*
K
])));
if
constexpr
(
YTILE
>=
3
)
bigB
[
2
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
2
*
K
])));
if
constexpr
(
YTILE
>=
4
)
bigB
[
3
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
3
*
K
])));
if
constexpr
(
YTILE
>=
5
)
bigB
[
4
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
4
*
K
])));
if
constexpr
(
YTILE
>=
6
)
bigB
[
5
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
5
*
K
])));
if
constexpr
(
YTILE
>=
7
)
bigB
[
6
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
6
*
K
])));
if
constexpr
(
YTILE
>=
8
)
bigB
[
7
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
7
*
K
])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
// Fetch A activation matrix in interleaved fashion from LDS or memory
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
if
(
k_
+
K
*
n
<
32
*
1024
)
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
s
[
k_
+
K
*
n
])));
else
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
A
[
k_
+
K
*
n
])));
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for
(
uint32_t
n
=
0
;
n
<
N
;
n
++
)
{
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for
(
uint32_t
b
=
0
;
b
<
A_CHUNK
/
2
;
b
++
)
{
DOT2C
(
sum
[
n
][
0
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
0
][
k2
].
f
[
b
]);
//----------------------------------------------------
// The following code with YTILE > 1
//----------------------------------------------------
if
constexpr
(
YTILE
>=
2
)
{
DOT2C
(
sum
[
n
][
1
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
1
][
k2
].
f
[
b
]);
}
if
constexpr
(
YTILE
>=
3
)
{
DOT2C
(
sum
[
n
][
2
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
2
][
k2
].
f
[
b
]);
}
if
constexpr
(
YTILE
>=
4
)
{
DOT2C
(
sum
[
n
][
3
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
3
][
k2
].
f
[
b
]);
}
if
constexpr
(
YTILE
>=
5
)
{
DOT2C
(
sum
[
n
][
4
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
4
][
k2
].
f
[
b
]);
}
if
constexpr
(
YTILE
>=
6
)
{
DOT2C
(
sum
[
n
][
5
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
5
][
k2
].
f
[
b
]);
}
if
constexpr
(
YTILE
>=
7
)
{
DOT2C
(
sum
[
n
][
6
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
6
][
k2
].
f
[
b
]);
}
if
constexpr
(
YTILE
>=
8
)
{
DOT2C
(
sum
[
n
][
7
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
7
][
k2
].
f
[
b
]);
}
}
}
}
}
//----------------------------------------------------
// Final reduction step using shuffle
//----------------------------------------------------
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
}
}
if
(
threadIdx
.
x
==
63
)
{
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
i
=
0
;
i
<
YTILE
;
i
++
)
{
if
(
commitColumn
[
i
])
C
[
m
+
i
+
n
*
M
]
=
__float2s
<
scalar_t
>
(
sum
[
n
][
i
]);
}
}
}
m
+=
CuCount
*
_WvPrGrp
*
YTILE
;
// Check whether there will be fragmenation!
// This will happen only for the last wave!
if
(
m
<
M
&&
(
m
+
YTILE
)
>=
M
)
{
uint32_t
startColumn
=
M
-
YTILE
;
for
(
uint32_t
i
=
0
;
i
<
(
m
-
startColumn
);
i
++
)
{
commitColumn
[
i
]
=
0
;
}
m
=
startColumn
;
}
}
}
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
wvSplitK_hf_
(
const
int
K
,
const
int
M
,
const
scalar_t
*
B
,
const
scalar_t
*
__restrict__
A
,
scalar_t
*
C
,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
// This version targets big A[] cases, where it is much larger than LDS capacity
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
__launch_bounds__
(
WvPrGrp
*
THRDS
)
wvSplitK_hf_big_
(
const
int
K
,
const
int
M
,
const
scalar_t
*
B
,
const
scalar_t
*
__restrict__
A
,
scalar_t
*
C
,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
using
scalar8
=
__attribute__
((
__vector_size__
((
A_CHUNK
/
2
)
*
sizeof
(
float
))))
float
;
union
bigType
{
scalar_t
h
[
A_CHUNK
];
float
f
[
A_CHUNK
/
2
];
float2
f2
[
A_CHUNK
/
4
];
double
d
[
A_CHUNK
/
4
];
scalar8
h8
;
};
//----------------------------------------------------
// Reserving 64 KB of LDS to have 1 WG / CU
// Goal is to bring the activation matrix A to the LDS
// and use it across the lifetime of the work group
// TODO: When activation matrix is larger than 64 KB
// then this is not goint to work!
//----------------------------------------------------
__shared__
scalar_t
s
[
1024
*
32
];
//----------------------------------------------------
// Computation of columns that need to be committed to memory!
//----------------------------------------------------
uint32_t
commitColumn
[
YTILE
];
for
(
uint32_t
i
=
0
;
i
<
YTILE
;
i
++
)
{
commitColumn
[
i
]
=
1
;
}
// int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp);
if
(
threadIdx
.
y
>=
_WvPrGrp
)
return
;
//----------------------------------------------------
// Indexing function into the column of weight matrix B
// Algorithm does 64 lane k-splitting / wave and uses
// WG ID and Thread ID to find the index.
//----------------------------------------------------
uint32_t
m
=
(
blockIdx
.
x
*
_WvPrGrp
+
threadIdx
.
y
)
*
YTILE
;
// Check whether there will be fragmenation!
// This will happen only for the last wave!
if
(
m
<
M
&&
(
m
+
YTILE
)
>=
M
)
{
uint32_t
startColumn
=
M
-
YTILE
;
for
(
uint32_t
i
=
0
;
i
<
(
m
-
startColumn
);
i
++
)
{
commitColumn
[
i
]
=
0
;
}
m
=
startColumn
;
}
//----------------------------------------------------
// Fetch the activation matrix to LDS
// Loop iteration:
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements
// - Each WG will fetch 512 * 16 => 8K elements
// - Then the WG will move to another 8 K elements
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
#define PCML
#ifndef PCML
for
(
uint32_t
k
=
0
;
k
<
min
(
K
*
N
,
32
*
1024
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
uint32_t
k_in
=
k
+
((
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
);
if
(
k_in
>=
min
(
K
*
N
,
32
*
1024
))
break
;
*
((
bigType
*
)(
&
s
[
k_in
]))
=
*
((
bigType
*
)(
&
A
[
k_in
]));
}
__syncthreads
();
#endif
#define TUC (THRDS * UNRL * A_CHUNK)
uint32_t
kBase
=
0
;
// find biggest k size that fits in LDS
uint32_t
kFit
=
(
32
*
1024
)
/
N
;
// kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple
// of TUC
kFit
=
(
kFit
%
TUC
==
0
)
?
kFit
:
(
kFit
-
kFit
%
TUC
);
// round up to multiple of TUC
// if (kFit == 0) kFit = TUC;
kFit
=
min
(
kFit
,
K
);
float
sum
[
N
][
YTILE
];
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
// There are 16 waves per WG, and hence, each WG is
// working on 16 columns of weight matrix. Moreover,
// we tile in column direction by YTILE, so when YTILE=1
// the above math is right, however, when YTILE=2 then
// each wave will be working on 2 columns and WG will
// be working on 32 columns.
//
// Top level loop that makes WGs persistent!
// - WGs iterates across columns of weight matrix
// - Each wave within WG works on a given column(s)
// - After completing first set of columns, WGs start
// working on the next set of available columns
//----------------------------------------------------
#ifdef PCML
int
YW
=
(
YTILE
*
_WvPrGrp
);
uint32_t
Mrndp
=
(
M
%
YW
==
0
)
?
M
:
(
M
-
M
%
YW
+
YW
);
while
(
m
<
Mrndp
)
{
#else
while
(
m
<
M
)
{
#endif
//----------------------------------------------------
// 'sum' accumulates the matrix A x B computation
// split across 64 lanes.
//
// YTILE represents how many column of weight matrix
// are being worked on by each wave.
//----------------------------------------------------
for
(
int
i
=
0
;
i
<
YTILE
;
i
++
)
for
(
int
n
=
0
;
n
<
N
;
n
++
)
sum
[
n
][
i
]
=
0
;
bigType
bigA
[
N
][
UNRL
];
bigType
bigB
[
YTILE
][
UNRL
];
//----------------------------------------------------
// Fetch weight matrix B in interleaved K-split!
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements (1024B)
// - YTILE represents the number of column being serviced
// by wave
// - Loop for fetching weight matrix (B) are unrolled
//
// Fetch activation matrix A from LDS
// - Loop for fetching activation matrix (A) are unrolled
//
// Finally, do the matrix multiplication in an unrolled
// fashion. This provides lot of food for compiler
// scheduling.
//
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for
(
uint32_t
k1
=
0
;
k1
<
K
;
k1
+=
THRDS
*
A_CHUNK
*
UNRL
)
{
#ifdef PCML
if
((
k1
==
0
)
||
(
k1
==
kBase
+
kFit
))
{
// load next chunk of A[] to LDS
if
(
k1
!=
0
)
kBase
+=
kFit
;
__syncthreads
();
for
(
uint32_t
k
=
0
;
k
<
kFit
;
k
+=
THRDS
*
_WvPrGrp
*
A_CHUNK
)
{
uint32_t
kOff
=
k
+
((
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
);
if
(
kBase
+
kOff
>=
K
)
break
;
if
(
kOff
>=
kFit
)
break
;
for
(
uint32_t
n
=
0
;
n
<
N
;
n
++
)
{
uint32_t
k_in
=
kBase
+
n
*
K
+
kOff
;
uint32_t
k_ot
=
n
*
kFit
+
kOff
;
*
((
bigType
*
)(
&
s
[
k_ot
]))
=
*
((
bigType
*
)(
&
A
[
k_in
]));
}
}
__syncthreads
();
}
if
(
m
>=
M
)
continue
;
#endif
// Fetch the weight matrix from memory!
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
const
scalar_t
*
B_
=
&
B
[(
m
+
0
)
*
K
+
k_
];
bigB
[
0
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
0
*
K
])));
//----------------------------------------------------
// The following code with YTILE > 1 has to be deleted
//----------------------------------------------------
if
constexpr
(
YTILE
>=
2
)
bigB
[
1
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
1
*
K
])));
if
constexpr
(
YTILE
>=
3
)
bigB
[
2
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
2
*
K
])));
if
constexpr
(
YTILE
>=
4
)
bigB
[
3
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
3
*
K
])));
if
constexpr
(
YTILE
>=
5
)
bigB
[
4
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
4
*
K
])));
if
constexpr
(
YTILE
>=
6
)
bigB
[
5
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
5
*
K
])));
if
constexpr
(
YTILE
>=
7
)
bigB
[
6
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
6
*
K
])));
if
constexpr
(
YTILE
>=
8
)
bigB
[
7
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
7
*
K
])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
// Fetch A activation matrix in interleaved fashion from LDS or memory
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
#ifdef PCML
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
s
[
k_
-
kBase
+
kFit
*
n
])));
#else
if
(
k_
+
K
*
n
<
32
*
1024
)
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
s
[
k_
+
K
*
n
])));
else
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
A
[
k_
+
K
*
n
])));
#endif
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
#pragma unroll
for
(
uint32_t
n
=
0
;
n
<
N
;
n
++
)
{
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for
(
uint32_t
b
=
0
;
b
<
A_CHUNK
/
2
;
b
++
)
{
DOT2C
(
sum
[
n
][
0
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
0
][
k2
].
f
[
b
]);
//----------------------------------------------------
// The following code with YTILE > 1
//----------------------------------------------------
if
constexpr
(
YTILE
>=
2
)
{
DOT2C
(
sum
[
n
][
1
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
1
][
k2
].
f
[
b
]);
}
if
constexpr
(
YTILE
>=
3
)
{
DOT2C
(
sum
[
n
][
2
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
2
][
k2
].
f
[
b
]);
}
if
constexpr
(
YTILE
>=
4
)
{
DOT2C
(
sum
[
n
][
3
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
3
][
k2
].
f
[
b
]);
}
if
constexpr
(
YTILE
>=
5
)
{
DOT2C
(
sum
[
n
][
4
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
4
][
k2
].
f
[
b
]);
}
if
constexpr
(
YTILE
>=
6
)
{
DOT2C
(
sum
[
n
][
5
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
5
][
k2
].
f
[
b
]);
}
if
constexpr
(
YTILE
>=
7
)
{
DOT2C
(
sum
[
n
][
6
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
6
][
k2
].
f
[
b
]);
}
if
constexpr
(
YTILE
>=
8
)
{
DOT2C
(
sum
[
n
][
7
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
7
][
k2
].
f
[
b
]);
}
}
}
}
}
#ifdef PCML
if
(
m
>=
M
)
{
m
+=
CuCount
*
_WvPrGrp
*
YTILE
;
kBase
=
0
;
continue
;
}
#endif
//----------------------------------------------------
// Final reduction step using shuffle
//----------------------------------------------------
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
}
}
if
(
threadIdx
.
x
==
63
)
{
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
i
=
0
;
i
<
YTILE
;
i
++
)
{
if
(
commitColumn
[
i
])
C
[
m
+
i
+
n
*
M
]
=
__float2s
<
scalar_t
>
(
sum
[
n
][
i
]);
}
}
}
m
+=
CuCount
*
_WvPrGrp
*
YTILE
;
kBase
=
0
;
// Check whether there will be fragmenation!
// This will happen only for the last wave!
if
(
m
<
M
&&
(
m
+
YTILE
)
>=
M
)
{
uint32_t
startColumn
=
M
-
YTILE
;
for
(
uint32_t
i
=
0
;
i
<
(
m
-
startColumn
);
i
++
)
{
commitColumn
[
i
]
=
0
;
}
m
=
startColumn
;
}
}
}
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
wvSplitK_hf_big_
(
const
int
K
,
const
int
M
,
const
scalar_t
*
B
,
const
scalar_t
*
__restrict__
A
,
scalar_t
*
C
,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
int
mindiv
(
int
N
,
int
div1
,
int
div2
)
{
int
nPrRnd
=
div1
*
div2
;
int
rnds0
=
N
/
nPrRnd
;
nPrRnd
-=
div1
*
3
;
int
rnds3
=
N
/
nPrRnd
;
nPrRnd
-=
div1
;
int
rnds4
=
N
/
nPrRnd
;
nPrRnd
-=
div1
;
int
rnds5
=
N
/
nPrRnd
;
nPrRnd
-=
div1
;
int
rnds6
=
N
/
nPrRnd
;
nPrRnd
-=
div1
;
int
rnds7
=
N
/
nPrRnd
;
nPrRnd
-=
div1
;
int
rnds8
=
N
/
nPrRnd
;
nPrRnd
-=
div1
;
int
rnds9
=
N
/
nPrRnd
;
nPrRnd
-=
div1
;
int
rtn
=
div2
;
if
(
rnds0
==
rnds3
)
rtn
=
div2
-
3
;
if
(
rnds0
==
rnds4
)
rtn
=
div2
-
4
;
if
(
rnds0
==
rnds5
)
rtn
=
div2
-
5
;
if
(
rnds0
==
rnds6
)
rtn
=
div2
-
6
;
if
(
rnds0
==
rnds7
)
rtn
=
div2
-
7
;
if
(
rnds0
==
rnds8
)
rtn
=
div2
-
8
;
if
(
rnds0
==
rnds9
)
rtn
=
div2
-
9
;
return
rtn
;
}
torch
::
Tensor
wvSplitK
(
at
::
Tensor
&
in_a
,
at
::
Tensor
&
in_b
,
const
int64_t
CuCount
)
{
auto
M_in
=
in_a
.
size
(
0
);
auto
K_in
=
in_a
.
size
(
1
);
auto
N_in
=
in_b
.
size
(
0
);
TORCH_CHECK
(
in_a
.
dtype
()
==
in_b
.
dtype
());
TORCH_CHECK
(
K_in
%
8
==
0
,
"k % 8 == 0"
);
TORCH_CHECK
(
in_a
.
dtype
()
==
torch
::
kFloat16
||
in_a
.
dtype
()
==
torch
::
kBFloat16
);
auto
out_c
=
torch
::
empty
(
{
N_in
,
M_in
},
torch
::
TensorOptions
().
dtype
(
in_b
.
dtype
()).
device
(
in_b
.
device
()));
dim3
grid
(
CuCount
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
in_a
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
_N) \
{ \
dim3 block(64, _WvPrGrp); \
if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
CuCount); \
} else if (K_in * N_in <= 32 * 1024 * 1.2) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
CuCount); \
} else { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \
wvSplitK_hf_big_<fptype, 64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
CuCount); \
} \
}
AT_DISPATCH_REDUCED_FLOATING_TYPES
(
in_b
.
scalar_type
(),
"wvSplitK"
,
[
&
]
{
using
fptype
=
typename
scalar
<
scalar_t
>::
type
;
fptype
*
af4
=
reinterpret_cast
<
fptype
*>
(
in_a
.
data_ptr
());
const
fptype
*
bf4
=
reinterpret_cast
<
const
fptype
*>
(
in_b
.
data_ptr
());
fptype
*
c
=
reinterpret_cast
<
fptype
*>
(
out_c
.
data_ptr
());
switch
(
N_in
)
{
case
1
:
WVSPLITK
(
16
,
2
,
2
,
2
,
2
,
2
,
2
,
1
)
break
;
case
2
:
WVSPLITK
(
16
,
2
,
2
,
2
,
2
,
2
,
2
,
2
)
break
;
case
3
:
WVSPLITK
(
16
,
4
,
7
,
7
,
1
,
1
,
1
,
3
)
break
;
case
4
:
WVSPLITK
(
16
,
4
,
7
,
7
,
1
,
1
,
1
,
4
)
break
;
default:
throw
std
::
runtime_error
(
"Unsupported N value: "
+
std
::
to_string
(
M_in
)
+
","
+
std
::
to_string
(
K_in
)
+
","
+
std
::
to_string
(
N_in
));
}
});
return
out_c
;
}
#if defined(__HIP__MI300__) // TODO: Add NAVI support
template
<
typename
scalar_t
,
typename
fp8_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
__launch_bounds__
(
WvPrGrp
*
THRDS
)
wvSplitKQ_hf_sml_
(
const
int
K
,
const
int
Kp
,
const
int
M
,
const
fp8_t
*
B
,
const
fp8_t
*
__restrict__
A
,
scalar_t
*
C
,
const
float
*
__restrict__
s_A
,
const
float
*
__restrict__
s_B
,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
using
scalar8
=
__attribute__
((
__vector_size__
((
A_CHUNK
/
4
)
*
sizeof
(
float
))))
float
;
using
intx2
=
__attribute__
((
__vector_size__
(
2
*
sizeof
(
int
))))
int
;
using
intx4
=
__attribute__
((
__vector_size__
(
4
*
sizeof
(
int
))))
int
;
union
bigType
{
char
f8
[
A_CHUNK
];
char2
c2
[
A_CHUNK
/
2
];
scalar_t
h
[
A_CHUNK
/
2
];
float
f
[
A_CHUNK
/
4
];
int
i
[
A_CHUNK
/
4
];
long
l
[
A_CHUNK
/
8
];
intx4
l2
[
A_CHUNK
/
16
];
scalar8
h8
;
};
__shared__
fp8_t
s
[
1024
*
64
];
for
(
uint32_t
k
=
(
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
;
k
<
min
(
K
*
N
,
64
*
1024
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
*
((
bigType
*
)(
&
s
[
k
]))
=
*
((
bigType
*
)(
&
A
[
k
]));
}
__syncthreads
();
if
(
threadIdx
.
y
>=
_WvPrGrp
)
return
;
uint32_t
m
=
(
blockIdx
.
x
*
_WvPrGrp
+
(
threadIdx
.
y
%
_WvPrGrp
))
*
YTILE
;
using
floatx16
=
__attribute__
((
__vector_size__
(
16
*
sizeof
(
float
))))
float
;
floatx16
sum
[
N
][
YTILE
];
float
sA
=
*
s_A
;
float
sB
=
*
s_B
;
while
(
m
<
M
)
{
for
(
int
i
=
0
;
i
<
YTILE
;
i
++
)
for
(
int
n
=
0
;
n
<
N
;
n
++
)
sum
[
n
][
i
]
=
{
0.
f
};
bigType
bigA
[
N
][
UNRL
];
bigType
bigB
[
YTILE
][
UNRL
];
for
(
uint32_t
k1
=
0
;
k1
<
K
;
k1
+=
THRDS
*
A_CHUNK
*
UNRL
)
{
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
#pragma unroll
for
(
uint32_t
n
=
0
;
n
<
N
;
++
n
)
bigA
[
n
][
k2
].
h8
=
{
0.
f
};
#pragma unroll
for
(
uint32_t
y
=
0
;
y
<
YTILE
;
++
y
)
bigB
[
y
][
k2
].
h8
=
{
0.
f
};
}
// Fetch the weight matrix from memory!
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
const
fp8_t
*
B_
=
&
B
[(
m
+
0
)
*
Kp
+
k_
];
#pragma unroll
for
(
uint32_t
y
=
0
;
y
<
YTILE
;
++
y
)
{
bigB
[
y
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
y
*
Kp
])));
}
}
// Fetch activation matrix from either just LDS or from both LDS / memory
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
s
[
k_
+
K
*
n
])));
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
if
(
k
>=
K
)
break
;
for
(
uint32_t
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
i
=
0
;
i
<
A_CHUNK
;
i
+=
8
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
++
y
)
{
sum
[
n
][
y
]
=
__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8
(
bigA
[
n
][
k2
].
l
[
i
/
8
],
bigB
[
y
][
k2
].
l
[
i
/
8
],
sum
[
n
][
y
],
0
,
0
,
0
);
}
}
}
}
}
// Final reduction
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
float
accm0
=
sum
[
n
][
y
][
0
];
float
accm16
=
sum
[
n
][
y
][
8
];
asm
(
"v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
1
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
9
]),
"v"
(
accm16
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
2
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
10
]),
"v"
(
accm16
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
3
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
11
]),
"v"
(
accm16
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
4
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
12
]),
"v"
(
accm16
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
5
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
13
]),
"v"
(
accm16
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
6
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
14
]),
"v"
(
accm16
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
7
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
15
]),
"v"
(
accm16
));
accm0
+=
__shfl
(
accm0
,
36
);
accm16
+=
__shfl
(
accm16
,
52
);
sum
[
n
][
y
][
0
]
=
accm0
+
__shfl
(
accm16
,
16
);
}
}
if
(
threadIdx
.
x
==
0
)
{
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
C
[
m
+
y
+
n
*
M
]
=
__float2s
<
scalar_t
>
(
sum
[
n
][
y
][
0
]
*
sA
*
sB
);
}
}
}
m
+=
CuCount
*
_WvPrGrp
*
YTILE
;
}
}
#else // !defined(__HIP__MI300__) TODO: Add NAVI support
template
<
typename
scalar_t
,
typename
fp8_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
wvSplitKQ_hf_sml_
(
const
int
K
,
const
int
Kp
,
const
int
M
,
const
fp8_t
*
B
,
const
fp8_t
*
__restrict__
A
,
scalar_t
*
C
,
const
float
*
__restrict__
s_A
,
const
float
*
__restrict__
s_B
,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300__) TODO: Add NAVI support
#if defined(__HIP__MI300__) // TODO: Add NAVI support
template
<
typename
scalar_t
,
typename
fp8_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
__launch_bounds__
(
WvPrGrp
*
THRDS
)
wvSplitKQ_hf_
(
const
int
K
,
const
int
Kp
,
const
int
M
,
const
fp8_t
*
B
,
const
fp8_t
*
__restrict__
A
,
scalar_t
*
C
,
const
float
*
__restrict__
s_A
,
const
float
*
__restrict__
s_B
,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
using
scalar8
=
__attribute__
((
__vector_size__
((
A_CHUNK
/
4
)
*
sizeof
(
float
))))
float
;
using
intx2
=
__attribute__
((
__vector_size__
(
2
*
sizeof
(
int
))))
int
;
using
intx4
=
__attribute__
((
__vector_size__
(
4
*
sizeof
(
int
))))
int
;
union
bigType
{
char
f8
[
A_CHUNK
];
char2
c2
[
A_CHUNK
/
2
];
scalar_t
h
[
A_CHUNK
/
2
];
float
f
[
A_CHUNK
/
4
];
int
i
[
A_CHUNK
/
4
];
long
l
[
A_CHUNK
/
8
];
intx4
l2
[
A_CHUNK
/
16
];
scalar8
h8
;
};
__shared__
fp8_t
s
[
1024
*
64
];
for
(
uint32_t
k
=
(
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
;
k
<
min
(
K
*
N
,
64
*
1024
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
*
((
bigType
*
)(
&
s
[
k
]))
=
*
((
bigType
*
)(
&
A
[
k
]));
}
__syncthreads
();
if
(
threadIdx
.
y
>=
_WvPrGrp
)
return
;
uint32_t
m
=
(
blockIdx
.
x
*
_WvPrGrp
+
(
threadIdx
.
y
%
_WvPrGrp
))
*
YTILE
;
using
floatx16
=
__attribute__
((
__vector_size__
(
16
*
sizeof
(
float
))))
float
;
floatx16
sum
[
N
][
YTILE
];
float
sA
=
*
s_A
;
float
sB
=
*
s_B
;
while
(
m
<
M
)
{
for
(
int
i
=
0
;
i
<
YTILE
;
i
++
)
for
(
int
n
=
0
;
n
<
N
;
n
++
)
sum
[
n
][
i
]
=
{
0
};
bigType
bigA
[
N
][
UNRL
];
bigType
bigB
[
YTILE
][
UNRL
];
for
(
uint32_t
k1
=
0
;
k1
<
K
;
k1
+=
THRDS
*
A_CHUNK
*
UNRL
)
{
// Fetch the weight matrix from memory!
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
const
fp8_t
*
B_
=
&
B
[(
m
+
0
)
*
Kp
+
k_
];
for
(
int
y
=
0
;
y
<
YTILE
;
++
y
)
{
if
(
y
+
m
>=
M
)
break
;
// To avoid mem access fault.
bigB
[
y
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
y
*
Kp
])));
}
}
// Fetch activation matrix from either just LDS or from both LDS / memory
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
if
(
k_
+
K
*
n
<
64
*
1024
)
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
s
[
k_
+
K
*
n
])));
else
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
A
[
k_
+
K
*
n
])));
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
for
(
uint32_t
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
i
=
0
;
i
<
A_CHUNK
;
i
+=
8
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
++
y
)
{
sum
[
n
][
y
]
=
__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8
(
bigA
[
n
][
k2
].
l
[
i
/
8
],
bigB
[
y
][
k2
].
l
[
i
/
8
],
sum
[
n
][
y
],
0
,
0
,
0
);
}
}
}
}
}
// Final reduction
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
float
accm0
=
sum
[
n
][
y
][
0
];
float
accm16
=
sum
[
n
][
y
][
8
];
asm
(
"v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
1
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
9
]),
"v"
(
accm16
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
2
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
10
]),
"v"
(
accm16
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
3
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
11
]),
"v"
(
accm16
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
4
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
12
]),
"v"
(
accm16
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
5
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
13
]),
"v"
(
accm16
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
6
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
14
]),
"v"
(
accm16
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
7
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
15
]),
"v"
(
accm16
));
accm0
+=
__shfl
(
accm0
,
36
);
accm16
+=
__shfl
(
accm16
,
52
);
sum
[
n
][
y
][
0
]
=
accm0
+
__shfl
(
accm16
,
16
);
}
}
if
(
threadIdx
.
x
==
0
)
{
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
if
(
y
+
m
>=
M
)
break
;
// To avoid mem access fault.
C
[
m
+
y
+
n
*
M
]
=
__float2s
<
scalar_t
>
(
sum
[
n
][
y
][
0
]
*
sA
*
sB
);
}
}
}
m
+=
CuCount
*
_WvPrGrp
*
YTILE
;
}
}
#else // !defined(__HIP__MI300__) TODO: Add NAVI support
template
<
typename
scalar_t
,
typename
fp8_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
wvSplitKQ_hf_
(
const
int
K
,
const
int
Kp
,
const
int
M
,
const
fp8_t
*
B
,
const
fp8_t
*
__restrict__
A
,
scalar_t
*
C
,
const
float
*
__restrict__
s_A
,
const
float
*
__restrict__
s_B
,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI300__) TODO: Add NAVI support
void
wvSplitKQ
(
at
::
Tensor
&
in_a
,
at
::
Tensor
&
in_b
,
at
::
Tensor
&
out_c
,
at
::
Tensor
&
scale_a
,
at
::
Tensor
&
scale_b
,
const
int64_t
CuCount
)
{
static
c10
::
ScalarType
kFp8Type
=
is_fp8_ocp
()
?
c10
::
ScalarType
::
Float8_e4m3fn
:
c10
::
ScalarType
::
Float8_e4m3fnuz
;
auto
M_in
=
in_a
.
size
(
0
);
auto
K_in
=
in_a
.
size
(
1
);
auto
N_in
=
in_b
.
size
(
0
);
auto
Kp_in
=
in_a
.
stride
(
0
);
TORCH_CHECK
(
K_in
%
16
==
0
,
"k % 16 == 0"
);
TORCH_CHECK
(
in_a
.
dtype
()
==
in_b
.
dtype
()
&&
in_a
.
dtype
()
==
kFp8Type
);
TORCH_CHECK
(
out_c
.
dtype
()
==
torch
::
kFloat16
||
out_c
.
dtype
()
==
torch
::
kBFloat16
);
dim3
grid
(
CuCount
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
in_a
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
_N) \
{ \
dim3 block(64, _WvPrGrp); \
if ((K_in * N_in <= 64 * 1024) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \
s_a, s_b, __wvPrGrp, CuCount); \
} else { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
wvSplitKQ_hf_<fptype, fp8_t, 64, _YTILEm, _WvPrGrp, 16, _UNRLm, _N> \
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \
s_a, s_b, __wvPrGrp, CuCount); \
} \
}
AT_DISPATCH_REDUCED_FLOATING_TYPES
(
out_c
.
scalar_type
(),
"wvSplitKQ"
,
[
&
]
{
using
fptype
=
typename
scalar
<
scalar_t
>::
type
;
auto
c_ptr
=
reinterpret_cast
<
fptype
*>
(
out_c
.
data_ptr
());
auto
s_a
=
scale_a
.
data_ptr
<
float
>
();
auto
s_b
=
scale_b
.
data_ptr
<
float
>
();
VLLM_DISPATCH_FP8_TYPES
(
in_a
.
scalar_type
(),
"wvSplitKQ"
,
[
&
]
{
auto
a_ptr
=
in_a
.
data_ptr
<
fp8_t
>
();
auto
b_ptr
=
in_b
.
data_ptr
<
fp8_t
>
();
switch
(
N_in
)
{
case
1
:
WVSPLITKQ
(
16
,
2
,
2
,
2
,
2
,
2
,
2
,
1
)
break
;
case
2
:
WVSPLITKQ
(
16
,
2
,
2
,
2
,
2
,
2
,
2
,
2
)
break
;
case
3
:
WVSPLITKQ
(
16
,
4
,
7
,
7
,
1
,
1
,
1
,
3
)
break
;
case
4
:
WVSPLITKQ
(
16
,
4
,
7
,
7
,
1
,
1
,
1
,
4
)
break
;
default:
throw
std
::
runtime_error
(
"Unsupported N value: "
+
std
::
to_string
(
M_in
)
+
","
+
std
::
to_string
(
K_in
)
+
","
+
std
::
to_string
(
N_in
));
}
});
});
}
csrc/rocm/torch_bindings.cpp
View file @
dcb5624a
...
@@ -14,6 +14,24 @@
...
@@ -14,6 +14,24 @@
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
rocm_ops
)
{
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
rocm_ops
)
{
// vLLM custom ops for rocm
// vLLM custom ops for rocm
// Custom gemm op for matrix-vector multiplication
rocm_ops
.
def
(
"LLMM1(Tensor in_a, Tensor in_b, int rows_per_block) -> "
"Tensor"
);
rocm_ops
.
impl
(
"LLMM1"
,
torch
::
kCUDA
,
&
LLMM1
);
// Custom gemm op for skinny matrix-matrix multiplication
rocm_ops
.
def
(
"wvSplitK(Tensor in_a, Tensor in_b, int CuCount) -> "
"Tensor"
);
rocm_ops
.
impl
(
"wvSplitK"
,
torch
::
kCUDA
,
&
wvSplitK
);
// wvSplitK for fp8
rocm_ops
.
def
(
"wvSplitKQ(Tensor in_a, Tensor in_b, Tensor! out_c, Tensor scale_a, "
" Tensor scale_b, int CuCount) -> ()"
);
rocm_ops
.
impl
(
"wvSplitKQ"
,
torch
::
kCUDA
,
&
wvSplitKQ
);
// Custom attention op
// Custom attention op
// Compute the attention between an input query and the cached
// Compute the attention between an input query and the cached
// keys/values using PagedAttention.
// keys/values using PagedAttention.
...
...
csrc/torch_bindings.cpp
View file @
dcb5624a
...
@@ -294,6 +294,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -294,6 +294,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
") -> ()"
);
") -> ()"
);
ops
.
impl
(
"advance_step_flashinfer"
,
torch
::
kCUDA
,
&
advance_step_flashinfer
);
ops
.
impl
(
"advance_step_flashinfer"
,
torch
::
kCUDA
,
&
advance_step_flashinfer
);
// Compute MLA decode using cutlass.
// ops.def(
// "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
// " Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
// " Tensor page_table, float scale) -> ()");
// ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
// Layernorm
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops
.
def
(
ops
.
def
(
...
...
docker/Dockerfile
View file @
dcb5624a
...
@@ -162,6 +162,9 @@ ENV UV_HTTP_TIMEOUT=500
...
@@ -162,6 +162,9 @@ ENV UV_HTTP_TIMEOUT=500
COPY
requirements/lint.txt requirements/lint.txt
COPY
requirements/lint.txt requirements/lint.txt
COPY
requirements/test.txt requirements/test.txt
COPY
requirements/test.txt requirements/test.txt
COPY
requirements/dev.txt requirements/dev.txt
COPY
requirements/dev.txt requirements/dev.txt
# Workaround for #17068
RUN
--mount
=
type
=
cache,target
=
/root/.cache/uv
\
uv pip
install
--system
mamba-ssm
==
2.2.4
--no-build-isolation
RUN
--mount
=
type
=
cache,target
=
/root/.cache/uv
\
RUN
--mount
=
type
=
cache,target
=
/root/.cache/uv
\
uv pip
install
--system
-r
requirements/dev.txt
uv pip
install
--system
-r
requirements/dev.txt
#################### DEV IMAGE ####################
#################### DEV IMAGE ####################
...
@@ -240,6 +243,8 @@ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
...
@@ -240,6 +243,8 @@ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
uv pip
install
--system
https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.1.post2/flashinfer_python-0.2.1.post2+cu124torch2.6-cp38-abi3-linux_x86_64.whl
;
\
uv pip
install
--system
https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.1.post2/flashinfer_python-0.2.1.post2+cu124torch2.6-cp38-abi3-linux_x86_64.whl
;
\
fi
fi
COPY
examples examples
COPY
examples examples
COPY
benchmarks benchmarks
COPY
./vllm/collect_env.py .
# Although we build Flashinfer with AOT mode, there's still
# Although we build Flashinfer with AOT mode, there's still
# some issues w.r.t. JIT compilation. Therefore we need to
# some issues w.r.t. JIT compilation. Therefore we need to
...
@@ -263,6 +268,9 @@ ADD . /vllm-workspace/
...
@@ -263,6 +268,9 @@ ADD . /vllm-workspace/
ENV
UV_HTTP_TIMEOUT=500
ENV
UV_HTTP_TIMEOUT=500
# install development dependencies (for testing)
# install development dependencies (for testing)
# Workaround for #17068
RUN
--mount
=
type
=
cache,target
=
/root/.cache/uv
\
uv pip
install
--system
mamba-ssm
==
2.2.4
--no-build-isolation
RUN
--mount
=
type
=
cache,target
=
/root/.cache/uv
\
RUN
--mount
=
type
=
cache,target
=
/root/.cache/uv
\
uv pip
install
--system
-r
requirements/dev.txt
uv pip
install
--system
-r
requirements/dev.txt
...
@@ -289,6 +297,7 @@ RUN mv vllm test_docs/
...
@@ -289,6 +297,7 @@ RUN mv vllm test_docs/
#################### OPENAI API SERVER ####################
#################### OPENAI API SERVER ####################
# base openai image with additional requirements, for any subsequent openai-style images
# base openai image with additional requirements, for any subsequent openai-style images
FROM
vllm-base AS vllm-openai-base
FROM
vllm-base AS vllm-openai-base
ARG
TARGETPLATFORM
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
# Reference: https://github.com/astral-sh/uv/pull/1694
...
...
docker/Dockerfile.cpu
View file @
dcb5624a
...
@@ -121,6 +121,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
...
@@ -121,6 +121,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
ADD ./tests/ ./tests/
ADD ./tests/ ./tests/
ADD ./examples/ ./examples/
ADD ./examples/ ./examples/
ADD ./benchmarks/ ./benchmarks/
ADD ./benchmarks/ ./benchmarks/
ADD ./vllm/collect_env.py .
# install development dependencies (for testing)
# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
RUN --mount=type=cache,target=/root/.cache/uv \
...
...
docker/Dockerfile.nightly_torch
0 → 100644
View file @
dcb5624a
# The vLLM Dockerfile is used to construct vLLM image against torch nightly that can be directly used for testing
# for torch nightly, cuda >=12.6 is required,
# use 12.8 due to FlashAttention issue with cuda 12.6 (https://github.com/vllm-project/vllm/issues/15435#issuecomment-2775924628)
ARG CUDA_VERSION=12.8.0
#
#################### BASE BUILD IMAGE ####################
# prepare basic build environment
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base
ARG CUDA_VERSION=12.8.0
ARG PYTHON_VERSION=3.12
ARG TARGETPLATFORM
ENV DEBIAN_FRONTEND=noninteractive
# Install Python and other dependencies
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
&& apt-get update -y \
&& apt-get install -y ccache software-properties-common git curl sudo \
&& add-apt-repository ppa:deadsnakes/ppa \
&& apt-get update -y \
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
&& python3 --version \
&& python3 -m pip --version
# Install uv for faster pip installs
RUN --mount=type=cache,target=/root/.cache/uv \
python3 -m pip install uv
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500
# Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519
# as it was causing spam when compiling the CUTLASS kernels
RUN apt-get install -y gcc-10 g++-10
RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 110 --slave /usr/bin/g++ g++ /usr/bin/g++-10
RUN <<EOF
gcc --version
EOF
# Workaround for https://github.com/openai/triton/issues/2507 and
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
# this won't be needed for future versions of this docker image
# or future versions of triton.
RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
WORKDIR /workspace
# install build and runtime dependencies
COPY requirements/common.txt requirements/common.txt
COPY use_existing_torch.py use_existing_torch.py
COPY pyproject.toml pyproject.toml
# install build and runtime dependencies without stable torch version
RUN python3 use_existing_torch.py
# install torch nightly
ARG PINNED_TORCH_VERSION
RUN --mount=type=cache,target=/root/.cache/uv \
if [ -n "$PINNED_TORCH_VERSION" ]; then \
pkgs="$PINNED_TORCH_VERSION"; \
else \
pkgs="torch torchaudio torchvision"; \
fi && \
uv pip install --system $pkgs --index-url https://download.pytorch.org/whl/nightly/cu128
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system numba==0.61.2
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/common.txt
# must put before installing xformers, so it can install the correct version of xfomrers.
ARG torch_cuda_arch_list='8.0;8.6;8.9;9.0'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
# Build xformers with cuda and torch nightly
# following official xformers guidance: https://github.com/facebookresearch/xformers#build
# todo(elainewy): cache xformers build result for faster build
ARG max_jobs=16
ENV MAX_JOBS=${max_jobs}
ARG XFORMERS_COMMIT=f2de641ef670510cadab099ce6954031f52f191c
ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \
--mount=type=cache,target=/root/.cache/uv \
echo 'git clone xformers...' \
&& git clone https://github.com/facebookresearch/xformers.git --recursive \
&& cd xformers \
&& git checkout ${XFORMERS_COMMIT} \
&& git submodule update --init --recursive \
&& echo 'finish git clone xformers...' \
&& rm -rf build \
&& python3 setup.py bdist_wheel --dist-dir=../xformers-dist --verbose \
&& cd .. \
&& rm -rf xformers
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system xformers-dist/*.whl --verbose
# build can take a long time, and the torch nightly version fetched from url can be different in next docker stage.
# track the nightly torch version used in the build, when we set up runtime environment we can make sure the version is the same
RUN uv pip freeze | grep -i '^torch\|^torchvision\|^torchaudio' > torch_build_versions.txt
RUN cat torch_build_versions.txt
# cuda arch list used by torch
# can be useful for `test`
# explicitly set the list to avoid issues with torch 2.2
# see https://github.com/pytorch/pytorch/pull/123243
# Override the arch list for flash-attn to reduce the binary size
ARG vllm_fa_cmake_gpu_arches='80-real;90-real'
ENV VLLM_FA_CMAKE_GPU_ARCHES=${vllm_fa_cmake_gpu_arches}
#################### BASE BUILD IMAGE ####################
#################### WHEEL BUILD IMAGE ####################
FROM base AS build
ARG TARGETPLATFORM
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500
COPY . .
RUN python3 use_existing_torch.py
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/build.txt
ARG GIT_REPO_CHECK=0
RUN --mount=type=bind,source=.git,target=.git \
if [ "$GIT_REPO_CHECK" != "0" ]; then bash tools/check_repo.sh ; fi
# Max jobs used by Ninja to build extensions
ARG max_jobs=16
ENV MAX_JOBS=${max_jobs}
ARG nvcc_threads=2
ENV NVCC_THREADS=$nvcc_threads
ARG USE_SCCACHE
ARG SCCACHE_BUCKET_NAME=vllm-build-sccache
ARG SCCACHE_REGION_NAME=us-west-2
ARG SCCACHE_S3_NO_CREDENTIALS=0
# if USE_SCCACHE is set, use sccache to speed up compilation
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=.git,target=.git \
if [ "$USE_SCCACHE" = "1" ]; then \
echo "Installing sccache..." \
&& curl -L -o sccache.tar.gz https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz \
&& tar -xzf sccache.tar.gz \
&& sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \
&& rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \
&& export SCCACHE_BUCKET=${SCCACHE_BUCKET_NAME} \
&& export SCCACHE_REGION=${SCCACHE_REGION_NAME} \
&& export SCCACHE_S3_NO_CREDENTIALS=${SCCACHE_S3_NO_CREDENTIALS} \
&& export SCCACHE_IDLE_TIMEOUT=0 \
&& export CMAKE_BUILD_TYPE=Release \
&& sccache --show-stats \
&& python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38 \
&& sccache --show-stats; \
fi
ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \
--mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=.git,target=.git \
if [ "$USE_SCCACHE" != "1" ]; then \
# Clean any existing CMake artifacts
rm -rf .deps && \
mkdir -p .deps && \
python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \
fi
#################### WHEEL BUILD IMAGE ####################
################### VLLM INSTALLED IMAGE ####################
# Setup clean environment for vLLM and its dependencies for test and api server using ubuntu22.04 with AOT flashinfer
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS vllm-base
# prepare for environment starts
ARG CUDA_VERSION=12.8.0
ARG PYTHON_VERSION=3.12
WORKDIR /vllm-workspace
ENV DEBIAN_FRONTEND=noninteractive
ARG TARGETPLATFORM
RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \
echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment
# Install Python and other dependencies
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
&& apt-get update -y \
&& apt-get install -y ccache software-properties-common git curl wget sudo vim python3-pip \
&& apt-get install -y ffmpeg libsm6 libxext6 libgl1 \
&& add-apt-repository ppa:deadsnakes/ppa \
&& apt-get update -y \
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv libibverbs-dev \
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
&& python3 --version && python3 -m pip --version
RUN --mount=type=cache,target=/root/.cache/uv \
python3 -m pip install uv
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500
# Workaround for https://github.com/openai/triton/issues/2507 and
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
# this won't be needed for future versions of this docker image
# or future versions of triton.
RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
# get the nightly torch version used in the build to make sure the version is the same
COPY --from=base /workspace/torch_build_versions.txt ./torch_build_versions.txt
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system $(cat torch_build_versions.txt | xargs) --index-url https://download.pytorch.org/whl/nightly/cu128
# install the vllm wheel
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/vllm-dist \
--mount=type=cache,target=/root/.cache/uv \
uv pip install --system vllm-dist/*.whl --verbose
# install xformers again for the new environment
RUN --mount=type=bind,from=base,src=/workspace/xformers-dist,target=/vllm-workspace/xformers-dist \
--mount=type=cache,target=/root/.cache/uv \
uv pip install --system /vllm-workspace/xformers-dist/*.whl --verbose
ARG torch_cuda_arch_list='8.0;8.6;8.9;9.0'
# install package for build flashinfer
# see issue: https://github.com/flashinfer-ai/flashinfer/issues/738
RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2.post1
# build flashinfer for torch nightly from source around 10 mins
# release version: v0.2.2.post1
# todo(elainewy): cache flashinfer build result for faster build
ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \
--mount=type=cache,target=/root/.cache/uv \
echo "git clone flashinfer..." \
&& git clone --recursive https://github.com/flashinfer-ai/flashinfer.git \
&& cd flashinfer \
&& git checkout v0.2.2.post1 \
&& git submodule update --init --recursive \
&& echo "finish git clone flashinfer..." \
&& rm -rf build \
&& export TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} \
&& FLASHINFER_ENABLE_AOT=1 python3 setup.py bdist_wheel --dist-dir=../flashinfer-dist --verbose \
&& cd .. \
&& rm -rf flashinfer
# install flashinfer
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system flashinfer-dist/*.whl --verbose
# install common packages
COPY requirements/common.txt requirements/common.txt
COPY use_existing_torch.py use_existing_torch.py
COPY pyproject.toml pyproject.toml
COPY examples examples
COPY benchmarks benchmarks
COPY ./vllm/collect_env.py .
RUN python3 use_existing_torch.py
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/common.txt
################### VLLM INSTALLED IMAGE ####################
#################### UNITTEST IMAGE #############################
FROM vllm-base as test
COPY tests/ tests/
# install build and runtime dependencies without stable torch version
COPY requirements/nightly_torch_test.txt requirements/nightly_torch_test.txt
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500
# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -e tests/vllm_test_utils
# enable fast downloads from hf (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system hf_transfer
ENV HF_HUB_ENABLE_HF_TRANSFER 1
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/nightly_torch_test.txt
#################### UNITTEST IMAGE #############################
docker/Dockerfile.ppc64le
View file @
dcb5624a
...
@@ -126,13 +126,16 @@ RUN --mount=type=cache,target=/root/.cache/uv \
...
@@ -126,13 +126,16 @@ RUN --mount=type=cache,target=/root/.cache/uv \
FROM base-builder AS cv-builder
FROM base-builder AS cv-builder
ARG MAX_JOBS
ARG MAX_JOBS
ARG OPENCV_VERSION=84
ARG OPENCV_VERSION=86
# patch for version 4.11.0.86
ARG OPENCV_PATCH=97f3f39
ARG ENABLE_HEADLESS=1
ARG ENABLE_HEADLESS=1
RUN --mount=type=cache,target=/root/.cache/uv \
RUN --mount=type=cache,target=/root/.cache/uv \
source /opt/rh/gcc-toolset-13/enable && \
source /opt/rh/gcc-toolset-13/enable && \
git clone --recursive https://github.com/opencv/opencv-python.git -b ${OPENCV_VERSION} && \
git clone --recursive https://github.com/opencv/opencv-python.git -b ${OPENCV_VERSION} && \
cd opencv-python && \
cd opencv-python && \
sed -i 's/"setuptools==59.2.0",/"setuptools<70.0",/g' pyproject.toml && \
sed -i -E -e 's/"setuptools.+",/"setuptools",/g' pyproject.toml && \
cd opencv && git cherry-pick --no-commit $OPENCV_PATCH && cd .. && \
python -m build --wheel --installer=uv --outdir /opencvwheels/
python -m build --wheel --installer=uv --outdir /opencvwheels/
###############################################################
###############################################################
...
@@ -148,9 +151,15 @@ COPY --from=arrow-builder /tmp/control /dev/null
...
@@ -148,9 +151,15 @@ COPY --from=arrow-builder /tmp/control /dev/null
COPY --from=cv-builder /tmp/control /dev/null
COPY --from=cv-builder /tmp/control /dev/null
ARG VLLM_TARGET_DEVICE=cpu
ARG VLLM_TARGET_DEVICE=cpu
ARG GRPC_PYTHON_BUILD_SYSTEM_OPENSSL=1
# this step installs vllm and populates uv cache
# this step installs vllm and populates uv cache
# with all the transitive dependencies
# with all the transitive dependencies
RUN --mount=type=cache,target=/root/.cache/uv \
source /opt/rh/gcc-toolset-13/enable && \
git clone https://github.com/huggingface/xet-core.git && cd xet-core/hf_xet/ && \
uv pip install maturin && \
uv build --wheel --out-dir /hf_wheels/
RUN --mount=type=cache,target=/root/.cache/uv \
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,from=torch-builder,source=/torchwheels/,target=/torchwheels/,ro \
--mount=type=bind,from=torch-builder,source=/torchwheels/,target=/torchwheels/,ro \
--mount=type=bind,from=arrow-builder,source=/arrowwheels/,target=/arrowwheels/,ro \
--mount=type=bind,from=arrow-builder,source=/arrowwheels/,target=/arrowwheels/,ro \
...
@@ -159,7 +168,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
...
@@ -159,7 +168,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
source /opt/rh/gcc-toolset-13/enable && \
source /opt/rh/gcc-toolset-13/enable && \
uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl && \
uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl && \
sed -i -e 's/.*torch.*//g' /src/pyproject.toml /src/requirements/*.txt && \
sed -i -e 's/.*torch.*//g' /src/pyproject.toml /src/requirements/*.txt && \
uv pip install pandas pythran pybind11 && \
uv pip install pandas pythran pybind11
/hf_wheels/*.whl
&& \
# sentencepiece.pc is in some pkgconfig inside uv cache
# sentencepiece.pc is in some pkgconfig inside uv cache
export PKG_CONFIG_PATH=$(find / -type d -name "pkgconfig" 2>/dev/null | tr '\n' ':') && \
export PKG_CONFIG_PATH=$(find / -type d -name "pkgconfig" 2>/dev/null | tr '\n' ':') && \
uv pip install -r /src/requirements/common.txt -r /src/requirements/cpu.txt -r /src/requirements/build.txt --no-build-isolation && \
uv pip install -r /src/requirements/common.txt -r /src/requirements/cpu.txt -r /src/requirements/build.txt --no-build-isolation && \
...
@@ -247,8 +256,9 @@ RUN --mount=type=cache,target=/root/.cache/uv \
...
@@ -247,8 +256,9 @@ RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,from=torch-builder,source=/torchwheels/,target=/torchwheels/,ro \
--mount=type=bind,from=torch-builder,source=/torchwheels/,target=/torchwheels/,ro \
--mount=type=bind,from=arrow-builder,source=/arrowwheels/,target=/arrowwheels/,ro \
--mount=type=bind,from=arrow-builder,source=/arrowwheels/,target=/arrowwheels/,ro \
--mount=type=bind,from=cv-builder,source=/opencvwheels/,target=/opencvwheels/,ro \
--mount=type=bind,from=cv-builder,source=/opencvwheels/,target=/opencvwheels/,ro \
--mount=type=bind,from=vllmcache-builder,source=/hf_wheels/,target=/hf_wheels/,ro \
--mount=type=bind,from=vllmcache-builder,source=/vllmwheel/,target=/vllmwheel/,ro \
--mount=type=bind,from=vllmcache-builder,source=/vllmwheel/,target=/vllmwheel/,ro \
HOME=/root uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl /vllmwheel/*.whl
HOME=/root uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl
/hf_wheels/*.whl
/vllmwheel/*.whl
COPY ./ /workspace/vllm
COPY ./ /workspace/vllm
WORKDIR /workspace/vllm
WORKDIR /workspace/vllm
...
...
docker/Dockerfile.rocm_base
View file @
dcb5624a
...
@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
...
@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="1a7f4dfa"
ARG FA_BRANCH="1a7f4dfa"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="
8970b25b
"
ARG AITER_BRANCH="
7e1ed08
"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
FROM ${BASE_IMAGE} AS base
FROM ${BASE_IMAGE} AS base
...
...
docker/Dockerfile.s390x
View file @
dcb5624a
...
@@ -58,7 +58,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
...
@@ -58,7 +58,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
cd ../../python && \
cd ../../python && \
export PYARROW_PARALLEL=4 && \
export PYARROW_PARALLEL=4 && \
export ARROW_BUILD_TYPE=release && \
export ARROW_BUILD_TYPE=release && \
uv pip install -r requirements
/
build.txt && \
uv pip install -r requirements
-
build.txt && \
python setup.py build_ext --build-type=$ARROW_BUILD_TYPE --bundle-arrow-cpp bdist_wheel
python setup.py build_ext --build-type=$ARROW_BUILD_TYPE --bundle-arrow-cpp bdist_wheel
FROM python-install AS numa-build
FROM python-install AS numa-build
...
@@ -96,6 +96,22 @@ RUN --mount=type=cache,target=/root/.cache/uv \
...
@@ -96,6 +96,22 @@ RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install -v torch==${TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/nightly/cpu && \
uv pip install -v torch==${TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/nightly/cpu && \
python setup.py bdist_wheel
python setup.py bdist_wheel
FROM python-install AS hf-xet-builder
# Install hf-xet
WORKDIR /tmp
ENV CARGO_HOME=/root/.cargo
ENV RUSTUP_HOME=/root/.rustup
ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH"
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,from=rust,source=/root/.cargo,target=/root/.cargo,rw \
--mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \
git clone https://github.com/huggingface/xet-core.git && \
cd xet-core/hf_xet/ && \
uv pip install maturin patchelf && \
python -m maturin build --release --out dist && \
mkdir -p /tmp/hf-xet/dist && \
cp dist/*.whl /tmp/hf-xet/dist/
# Final build stage
# Final build stage
FROM python-install AS vllm-cpu
FROM python-install AS vllm-cpu
ARG PYTHON_VERSION
ARG PYTHON_VERSION
...
@@ -120,12 +136,15 @@ RUN --mount=type=cache,target=/root/.cache/uv \
...
@@ -120,12 +136,15 @@ RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \
--mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \
--mount=type=bind,from=pyarrow,source=/tmp/arrow/python/dist,target=/tmp/arrow-wheels \
--mount=type=bind,from=pyarrow,source=/tmp/arrow/python/dist,target=/tmp/arrow-wheels \
--mount=type=bind,from=torch-vision,source=/tmp/vision/dist,target=/tmp/vision-wheels/ \
--mount=type=bind,from=torch-vision,source=/tmp/vision/dist,target=/tmp/vision-wheels/ \
--mount=type=bind,from=hf-xet-builder,source=/tmp/hf-xet/dist,target=/tmp/hf-xet-wheels/ \
sed -i '/^torch/d' requirements/build.txt && \
sed -i '/^torch/d' requirements/build.txt && \
ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl | head -n 1) && \
ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl | head -n 1) && \
VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl | head -n 1) && \
VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl | head -n 1) && \
HF_XET_WHL_FILE=$(ls /tmp/hf-xet-wheels/*.whl | head -n 1) && \
uv pip install -v \
uv pip install -v \
$ARROW_WHL_FILE \
$ARROW_WHL_FILE \
$VISION_WHL_FILE \
$VISION_WHL_FILE \
$HF_XET_WHL_FILE \
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
--index-strategy unsafe-best-match \
--index-strategy unsafe-best-match \
-r requirements/build.txt \
-r requirements/build.txt \
...
@@ -149,4 +168,5 @@ USER 2000
...
@@ -149,4 +168,5 @@ USER 2000
WORKDIR /home/vllm
WORKDIR /home/vllm
# Set the default entrypoint
# Set the default entrypoint
ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"]
ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"]
\ No newline at end of file
docs/source/assets/deployment/anything-llm-chat-with-doc.png
0 → 100644
View file @
dcb5624a
118 KB
docs/source/assets/deployment/anything-llm-chat-without-doc.png
0 → 100644
View file @
dcb5624a
136 KB
docs/source/assets/deployment/anything-llm-provider.png
0 → 100644
View file @
dcb5624a
110 KB
Prev
1
2
3
4
5
6
7
8
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment