Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fcfc474d
"examples/vscode:/vscode.git/clone" did not exist on "bb94d2e5b14cd3d4e62438491e0ab569c5436cf4"
Commit
fcfc474d
authored
Apr 09, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.3' into v0.8.3-dev
parents
bb94d2e5
296c6572
Changes
503
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1359 additions
and
180 deletions
+1359
-180
csrc/cpu/torch_bindings.cpp
csrc/cpu/torch_bindings.cpp
+63
-0
csrc/cpu/utils.cpp
csrc/cpu/utils.cpp
+1
-1
csrc/cuda_view.cu
csrc/cuda_view.cu
+39
-0
csrc/custom_all_reduce.cu
csrc/custom_all_reduce.cu
+31
-27
csrc/custom_all_reduce.cuh
csrc/custom_all_reduce.cuh
+54
-50
csrc/custom_all_reduce_test.cu
csrc/custom_all_reduce_test.cu
+43
-28
csrc/cutlass_extensions/common.hpp
csrc/cutlass_extensions/common.hpp
+11
-1
csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp
...extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp
+457
-0
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
+66
-0
csrc/ops.h
csrc/ops.h
+22
-6
csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh
csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh
+80
-0
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu
+160
-0
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh
+149
-0
csrc/quantization/cutlass_w8a8/moe/moe_data.cu
csrc/quantization/cutlass_w8a8/moe/moe_data.cu
+90
-0
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+67
-0
csrc/quantization/fp8/common.cu
csrc/quantization/fp8/common.cu
+2
-5
csrc/quantization/fp8/common.cuh
csrc/quantization/fp8/common.cuh
+5
-36
csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
.../fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
+10
-18
csrc/quantization/fused_kernels/layernorm_utils.cuh
csrc/quantization/fused_kernels/layernorm_utils.cuh
+7
-6
csrc/quantization/fused_kernels/quant_conversions.cuh
csrc/quantization/fused_kernels/quant_conversions.cuh
+2
-2
No files found.
csrc/cpu/torch_bindings.cpp
View file @
fcfc474d
...
@@ -18,6 +18,30 @@ void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
...
@@ -18,6 +18,30 @@ void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
const
std
::
optional
<
torch
::
Tensor
>&
azp
,
const
std
::
optional
<
torch
::
Tensor
>&
azp
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
);
const
std
::
optional
<
torch
::
Tensor
>&
bias
);
void
mla_decode_kvcache
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
kv_cache
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
);
int64_t
init_shm_manager
(
const
std
::
string
&
name
,
const
int64_t
group_size
,
const
int64_t
rank
);
std
::
string
join_shm_manager
(
int64_t
handle
,
const
std
::
string
&
name
);
void
shm_allreduce
(
int64_t
handle
,
torch
::
Tensor
&
data
);
void
shm_gather
(
int64_t
handle
,
torch
::
Tensor
&
data
,
const
std
::
optional
<
std
::
vector
<
torch
::
Tensor
>>&
outputs
,
int64_t
dst
);
void
shm_all_gather
(
int64_t
handle
,
const
torch
::
Tensor
&
data
,
torch
::
Tensor
&
output
);
void
shm_send_tensor_list
(
int64_t
handle
,
const
std
::
vector
<
torch
::
Tensor
>&
tensor_list
,
int64_t
dst
);
std
::
vector
<
torch
::
Tensor
>
shm_recv_tensor_list
(
int64_t
handle
,
int64_t
src
);
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
ops
)
{
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
ops
)
{
// vLLM custom ops
// vLLM custom ops
...
@@ -127,6 +151,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -127,6 +151,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor? azp, Tensor? bias) -> ()"
);
" Tensor? azp, Tensor? bias) -> ()"
);
ops
.
impl
(
"cutlass_scaled_mm_azp"
,
torch
::
kCPU
,
&
int8_scaled_mm_azp
);
ops
.
impl
(
"cutlass_scaled_mm_azp"
,
torch
::
kCPU
,
&
int8_scaled_mm_azp
);
#endif
#endif
// SHM CCL
#ifdef __AVX512F__
ops
.
def
(
"init_shm_manager(str name, int group_size, int rank) -> int"
,
&
init_shm_manager
);
ops
.
def
(
"join_shm_manager(int handle, str name) -> str"
,
&
join_shm_manager
);
ops
.
def
(
"shm_allreduce(int handle, Tensor! data) -> ()"
);
ops
.
impl
(
"shm_allreduce"
,
torch
::
kCPU
,
&
shm_allreduce
);
ops
.
def
(
"shm_gather(int handle, Tensor data, Tensor[](a!)? outputs, int dst) -> "
"()"
);
ops
.
impl
(
"shm_gather"
,
torch
::
kCPU
,
&
shm_gather
);
ops
.
def
(
"shm_all_gather(int handle, Tensor data, Tensor! output) -> "
"()"
);
ops
.
impl
(
"shm_all_gather"
,
torch
::
kCPU
,
&
shm_all_gather
);
ops
.
def
(
"shm_send_tensor_list(int handle, Tensor[](a) tensor_list, int dst) -> "
"()"
);
ops
.
impl
(
"shm_send_tensor_list"
,
torch
::
kCPU
,
&
shm_send_tensor_list
);
ops
.
def
(
"shm_recv_tensor_list(int handle, int src) -> Tensor[](a)"
,
&
shm_recv_tensor_list
);
#endif
}
}
TORCH_LIBRARY_EXPAND
(
CONCAT
(
TORCH_EXTENSION_NAME
,
_cache_ops
),
cache_ops
)
{
TORCH_LIBRARY_EXPAND
(
CONCAT
(
TORCH_EXTENSION_NAME
,
_cache_ops
),
cache_ops
)
{
...
@@ -150,6 +197,14 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
...
@@ -150,6 +197,14 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" str kv_cache_dtype,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale) -> ()"
);
" Tensor k_scale, Tensor v_scale) -> ()"
);
cache_ops
.
impl
(
"reshape_and_cache"
,
torch
::
kCPU
,
&
reshape_and_cache
);
cache_ops
.
impl
(
"reshape_and_cache"
,
torch
::
kCPU
,
&
reshape_and_cache
);
cache_ops
.
def
(
"concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
" Tensor! kv_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor scale) -> ()"
);
cache_ops
.
impl
(
"concat_and_cache_mla"
,
torch
::
kCPU
,
&
concat_and_cache_mla
);
}
}
TORCH_LIBRARY_EXPAND
(
CONCAT
(
TORCH_EXTENSION_NAME
,
_utils
),
utils
)
{
TORCH_LIBRARY_EXPAND
(
CONCAT
(
TORCH_EXTENSION_NAME
,
_utils
),
utils
)
{
...
@@ -157,4 +212,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
...
@@ -157,4 +212,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
utils
.
def
(
"init_cpu_threads_env(str cpu_ids) -> str"
,
&
init_cpu_threads_env
);
utils
.
def
(
"init_cpu_threads_env(str cpu_ids) -> str"
,
&
init_cpu_threads_env
);
}
}
TORCH_LIBRARY_EXPAND
(
CONCAT
(
TORCH_EXTENSION_NAME
,
_cpu
),
cpu_ops
)
{
cpu_ops
.
def
(
"mla_decode_kvcache("
" Tensor! out, Tensor query, Tensor kv_cache,"
" float scale, Tensor block_tables, Tensor seq_lens) -> ()"
);
cpu_ops
.
impl
(
"mla_decode_kvcache"
,
torch
::
kCPU
,
&
mla_decode_kvcache
);
}
REGISTER_EXTENSION
(
TORCH_EXTENSION_NAME
)
REGISTER_EXTENSION
(
TORCH_EXTENSION_NAME
)
csrc/cpu/utils.cpp
View file @
fcfc474d
...
@@ -18,7 +18,7 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
...
@@ -18,7 +18,7 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
#ifndef VLLM_NUMA_DISABLED
#ifndef VLLM_NUMA_DISABLED
std
::
string
init_cpu_threads_env
(
const
std
::
string
&
cpu_ids
)
{
std
::
string
init_cpu_threads_env
(
const
std
::
string
&
cpu_ids
)
{
bitmask
*
omp_cpu_mask
=
numa_parse_cpustring
(
cpu_ids
.
c_str
());
bitmask
*
omp_cpu_mask
=
numa_parse_cpustring
_all
(
cpu_ids
.
c_str
());
TORCH_CHECK
(
omp_cpu_mask
->
size
>
0
);
TORCH_CHECK
(
omp_cpu_mask
->
size
>
0
);
std
::
vector
<
int
>
omp_cpu_ids
;
std
::
vector
<
int
>
omp_cpu_ids
;
omp_cpu_ids
.
reserve
(
omp_cpu_mask
->
size
);
omp_cpu_ids
.
reserve
(
omp_cpu_mask
->
size
);
...
...
csrc/cuda_view.cu
0 → 100644
View file @
fcfc474d
#include <torch/all.h>
#include <torch/cuda.h>
#include <cuda_runtime.h>
// This function assumes that `cpu_tensor` is a CPU tensor allocated with pinned
// memory, and that UVA (Unified Virtual Addressing) is enabled.
torch
::
Tensor
get_cuda_view_from_cpu_tensor
(
torch
::
Tensor
&
cpu_tensor
)
{
TORCH_CHECK
(
cpu_tensor
.
device
().
is_cpu
(),
"Input tensor must be on CPU"
);
// Get raw host pointer from CPU tensor
void
*
host_ptr
=
cpu_tensor
.
data_ptr
();
// Get a device pointer corresponding to the pinned host memory
void
*
device_ptr
=
nullptr
;
cudaError_t
err
=
cudaHostGetDevicePointer
(
&
device_ptr
,
host_ptr
,
0
);
TORCH_CHECK
(
err
==
cudaSuccess
,
"cudaHostGetDevicePointer failed: "
,
cudaGetErrorString
(
err
));
// We'll use the same sizes, strides, and dtype as the CPU tensor.
// TODO: check if layout is respected.
auto
sizes
=
cpu_tensor
.
sizes
();
auto
strides
=
cpu_tensor
.
strides
();
auto
options
=
cpu_tensor
.
options
().
device
(
torch
::
kCUDA
);
// from_blob signature: from_blob(void *data, IntArrayRef sizes, ..., Deleter,
// const TensorOptions &) Provide a no-op deleter. The CPU tensor holds the
// memory, so we don't free it here.
auto
deleter
=
[](
void
*
)
{
// no-op, since the memory is owned by the original CPU tensor
};
torch
::
Tensor
cuda_tensor
=
torch
::
from_blob
(
device_ptr
,
sizes
,
strides
,
deleter
,
options
);
TORCH_CHECK
(
cuda_tensor
.
device
().
is_cuda
(),
"Resulting tensor is not on CUDA device"
);
return
cuda_tensor
;
}
csrc/custom_all_reduce.cu
View file @
fcfc474d
...
@@ -12,7 +12,7 @@ static_assert(sizeof(void*) == sizeof(fptr_t));
...
@@ -12,7 +12,7 @@ static_assert(sizeof(void*) == sizeof(fptr_t));
fptr_t
init_custom_ar
(
const
std
::
vector
<
fptr_t
>&
fake_ipc_ptrs
,
fptr_t
init_custom_ar
(
const
std
::
vector
<
fptr_t
>&
fake_ipc_ptrs
,
torch
::
Tensor
&
rank_data
,
int64_t
rank
,
torch
::
Tensor
&
rank_data
,
int64_t
rank
,
bool
full
_nvlink
)
{
bool
full
y_connected
)
{
int
world_size
=
fake_ipc_ptrs
.
size
();
int
world_size
=
fake_ipc_ptrs
.
size
();
if
(
world_size
>
8
)
if
(
world_size
>
8
)
throw
std
::
invalid_argument
(
"world size > 8 is not supported"
);
throw
std
::
invalid_argument
(
"world size > 8 is not supported"
);
...
@@ -27,7 +27,7 @@ fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
...
@@ -27,7 +27,7 @@ fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
}
}
return
(
fptr_t
)
new
vllm
::
CustomAllreduce
(
ipc_ptrs
,
rank_data
.
data_ptr
(),
return
(
fptr_t
)
new
vllm
::
CustomAllreduce
(
ipc_ptrs
,
rank_data
.
data_ptr
(),
rank_data
.
numel
(),
rank
,
world_size
,
rank_data
.
numel
(),
rank
,
world_size
,
full
_nvlink
);
full
y_connected
);
}
}
/**
/**
...
@@ -144,34 +144,38 @@ void register_graph_buffers(fptr_t _fa,
...
@@ -144,34 +144,38 @@ void register_graph_buffers(fptr_t _fa,
}
}
std
::
tuple
<
fptr_t
,
torch
::
Tensor
>
allocate_shared_buffer_and_handle
(
std
::
tuple
<
fptr_t
,
torch
::
Tensor
>
allocate_shared_buffer_and_handle
(
int64_t
size
)
{
int64_t
size
)
{
auto
device_index
=
c10
::
cuda
::
current_device
();
auto
device_index
=
c10
::
cuda
::
current_device
();
at
::
DeviceGuard
device_guard
(
at
::
Device
(
at
::
DeviceType
::
CUDA
,
device_index
));
at
::
DeviceGuard
device_guard
(
at
::
Device
(
at
::
DeviceType
::
CUDA
,
device_index
));
void
*
buffer
;
void
*
buffer
;
cudaStreamCaptureMode
mode
=
cudaStreamCaptureModeRelaxed
;
cudaStreamCaptureMode
mode
=
cudaStreamCaptureModeRelaxed
;
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
().
stream
();
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
().
stream
();
AT_CUDA_CHECK
(
cudaThreadExchangeStreamCaptureMode
(
&
mode
));
AT_CUDA_CHECK
(
cudaThreadExchangeStreamCaptureMode
(
&
mode
));
// Allocate buffer
#if defined(USE_ROCM)
#if defined(USE_ROCM)
// data buffers need to be "uncached" for signal on MI200
// data buffers need to be "uncached" for signal on MI200
AT_CUDA_CHECK
(
AT_CUDA_CHECK
(
hipExtMallocWithFlags
((
void
**
)
&
buffer
,
size
,
hipDeviceMallocUncached
));
hipExtMallocWithFlags
((
void
**
)
&
buffer
,
size
,
hipDeviceMallocUncached
));
#else
#else
AT_CUDA_CHECK
(
cudaMalloc
((
void
**
)
&
buffer
,
size
));
AT_CUDA_CHECK
(
cudaMalloc
((
void
**
)
&
buffer
,
size
));
#endif
#endif
AT_CUDA_CHECK
(
cudaMemsetAsync
(
buffer
,
0
,
size
,
stream
));
AT_CUDA_CHECK
(
cudaMemsetAsync
(
buffer
,
0
,
size
,
stream
));
AT_CUDA_CHECK
(
cudaStreamSynchronize
(
stream
));
AT_CUDA_CHECK
(
cudaStreamSynchronize
(
stream
));
AT_CUDA_CHECK
(
cudaThreadExchangeStreamCaptureMode
(
&
mode
));
AT_CUDA_CHECK
(
cudaThreadExchangeStreamCaptureMode
(
&
mode
));
// Create IPC memhandle for the allocated buffer.
auto
options
=
// Will use it in open_mem_handle.
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
torch
::
kCPU
);
auto
options
=
auto
handle
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
torch
::
kCPU
);
torch
::
empty
({
static_cast
<
int64_t
>
(
sizeof
(
cudaIpcMemHandle_t
))},
options
);
auto
handle
=
AT_CUDA_CHECK
(
torch
::
empty
({
static_cast
<
int64_t
>
(
sizeof
(
cudaIpcMemHandle_t
))},
options
);
cudaIpcGetMemHandle
((
cudaIpcMemHandle_t
*
)
handle
.
data_ptr
(),
buffer
));
AT_CUDA_CHECK
(
cudaIpcGetMemHandle
((
cudaIpcMemHandle_t
*
)
handle
.
data_ptr
(),
buffer
));
return
std
::
make_tuple
(
reinterpret_cast
<
fptr_t
>
(
buffer
),
handle
);
return
std
::
make_tuple
(
reinterpret_cast
<
fptr_t
>
(
buffer
),
handle
);
}
}
fptr_t
open_mem_handle
(
torch
::
Tensor
&
mem_handle
)
{
fptr_t
open_mem_handle
(
torch
::
Tensor
&
mem_handle
)
{
void
*
ipc_ptr
;
void
*
ipc_ptr
;
AT_CUDA_CHECK
(
cudaIpcOpenMemHandle
(
AT_CUDA_CHECK
(
cudaIpcOpenMemHandle
(
...
@@ -182,4 +186,4 @@ fptr_t open_mem_handle(torch::Tensor& mem_handle) {
...
@@ -182,4 +186,4 @@ fptr_t open_mem_handle(torch::Tensor& mem_handle) {
void
free_shared_buffer
(
fptr_t
buffer
)
{
void
free_shared_buffer
(
fptr_t
buffer
)
{
AT_CUDA_CHECK
(
cudaFree
(
reinterpret_cast
<
void
*>
(
buffer
)));
AT_CUDA_CHECK
(
cudaFree
(
reinterpret_cast
<
void
*>
(
buffer
)));
}
}
\ No newline at end of file
csrc/custom_all_reduce.cuh
View file @
fcfc474d
...
@@ -33,8 +33,7 @@ constexpr int kMaxBlocks = 36;
...
@@ -33,8 +33,7 @@ constexpr int kMaxBlocks = 36;
// Default number of blocks in allreduce kernel.
// Default number of blocks in allreduce kernel.
#ifndef USE_ROCM
#ifndef USE_ROCM
const
int
defaultBlockLimit
=
36
;
const
int
defaultBlockLimit
=
36
;
CUpointer_attribute
rangeStartAddrAttr
=
CUpointer_attribute
rangeStartAddrAttr
=
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR
;
CUDA_POINTER_ATTRIBUTE_RANGE_START_ADDR
;
#else
#else
const
int
defaultBlockLimit
=
16
;
const
int
defaultBlockLimit
=
16
;
hipPointer_attribute
rangeStartAddrAttr
=
hipPointer_attribute
rangeStartAddrAttr
=
...
@@ -45,11 +44,12 @@ hipPointer_attribute rangeStartAddrAttr =
...
@@ -45,11 +44,12 @@ hipPointer_attribute rangeStartAddrAttr =
// well-defined behavior.
// well-defined behavior.
using
FlagType
=
uint32_t
;
using
FlagType
=
uint32_t
;
// Two sets of peer counters are needed for two syncs. The reason is that
// Two sets of peer counters are needed for two syncs: starting and ending an
// it's possible for peer GPU block to arrive at the second sync point while
// operation. The reason is that it's possible for peer GPU block to arrive at
// the current GPU block haven't passed the first sync point. Thus, peer GPU
// the second sync point while the current GPU block haven't passed the first
// may write counter+1 while current GPU is busy waiting for counter. We use
// sync point. Thus, peer GPU may write counter+1 while current GPU is busy
// alternating counter array to avoid this possibility.
// waiting for counter. We use alternating counter array to avoid this
// possibility.
struct
Signal
{
struct
Signal
{
alignas
(
128
)
FlagType
start
[
kMaxBlocks
][
8
];
alignas
(
128
)
FlagType
start
[
kMaxBlocks
][
8
];
alignas
(
128
)
FlagType
end
[
kMaxBlocks
][
8
];
alignas
(
128
)
FlagType
end
[
kMaxBlocks
][
8
];
...
@@ -195,7 +195,8 @@ static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
...
@@ -195,7 +195,8 @@ static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
// prior memory accesses. Note: volatile writes will not be reordered against
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
// other volatile writes.
template
<
int
ngpus
>
template
<
int
ngpus
>
DINLINE
void
start_sync
(
const
RankSignals
&
sg
,
Signal
*
self_sg
,
int
rank
)
{
DINLINE
void
barrier_at_start
(
const
RankSignals
&
sg
,
Signal
*
self_sg
,
int
rank
)
{
uint32_t
flag
=
self_sg
->
_flag
[
blockIdx
.
x
]
+
1
;
uint32_t
flag
=
self_sg
->
_flag
[
blockIdx
.
x
]
+
1
;
if
(
threadIdx
.
x
<
ngpus
)
{
if
(
threadIdx
.
x
<
ngpus
)
{
auto
peer_counter_ptr
=
&
sg
.
signals
[
threadIdx
.
x
]
->
start
[
blockIdx
.
x
][
rank
];
auto
peer_counter_ptr
=
&
sg
.
signals
[
threadIdx
.
x
]
->
start
[
blockIdx
.
x
][
rank
];
...
@@ -215,7 +216,7 @@ DINLINE void start_sync(const RankSignals& sg, Signal* self_sg, int rank) {
...
@@ -215,7 +216,7 @@ DINLINE void start_sync(const RankSignals& sg, Signal* self_sg, int rank) {
// synchronization barrier, we don't need to make any visibility guarantees
// synchronization barrier, we don't need to make any visibility guarantees
// for prior memory accesses.
// for prior memory accesses.
template
<
int
ngpus
,
bool
final_sync
=
false
>
template
<
int
ngpus
,
bool
final_sync
=
false
>
DINLINE
void
end_sync
(
const
RankSignals
&
sg
,
Signal
*
self_sg
,
int
rank
)
{
DINLINE
void
barrier_at_end
(
const
RankSignals
&
sg
,
Signal
*
self_sg
,
int
rank
)
{
__syncthreads
();
__syncthreads
();
uint32_t
flag
=
self_sg
->
_flag
[
blockIdx
.
x
]
+
1
;
uint32_t
flag
=
self_sg
->
_flag
[
blockIdx
.
x
]
+
1
;
if
(
threadIdx
.
x
<
ngpus
)
{
if
(
threadIdx
.
x
<
ngpus
)
{
...
@@ -240,20 +241,20 @@ DINLINE void end_sync(const RankSignals& sg, Signal* self_sg, int rank) {
...
@@ -240,20 +241,20 @@ DINLINE void end_sync(const RankSignals& sg, Signal* self_sg, int rank) {
#else
#else
template
<
int
ngpus
>
template
<
int
ngpus
>
DINLINE
void
start_sync
(
const
RankSignals
&
sg
,
Signal
*
self_sg
,
int
rank
)
{
DINLINE
void
barrier_at_start
(
const
RankSignals
&
sg
,
Signal
*
self_sg
,
int
rank
)
{
uint32_t
flag
=
self_sg
->
_flag
[
blockIdx
.
x
]
+
1
;
uint32_t
flag
=
self_sg
->
_flag
[
blockIdx
.
x
]
+
1
;
if
(
threadIdx
.
x
<
ngpus
)
{
if
(
threadIdx
.
x
<
ngpus
)
{
// simultaneously write to the corresponding flag of all ranks.
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
// Latency = 1 p2p write
// __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
// __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
// flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
// flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
// // wait until we got true from all ranks
// while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
// __ATOMIC_RELAXED,
// __MEMORY_SCOPE_DEVICE) < flag);
__atomic_store_n
(
&
sg
.
signals
[
threadIdx
.
x
]
->
start
[
blockIdx
.
x
][
rank
],
flag
,
__atomic_store_n
(
&
sg
.
signals
[
threadIdx
.
x
]
->
start
[
blockIdx
.
x
][
rank
],
flag
,
__ATOMIC_RELAXED
);
__ATOMIC_RELAXED
);
// wait until we got true from all ranks
// wait until we got true from all ranks
// while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
// __ATOMIC_RELAXED,
// __MEMORY_SCOPE_DEVICE) < flag);
while
(
__atomic_load_n
(
&
self_sg
->
start
[
blockIdx
.
x
][
threadIdx
.
x
],
while
(
__atomic_load_n
(
&
self_sg
->
start
[
blockIdx
.
x
][
threadIdx
.
x
],
__ATOMIC_RELAXED
)
<
flag
);
__ATOMIC_RELAXED
)
<
flag
);
}
}
...
@@ -263,7 +264,7 @@ DINLINE void start_sync(const RankSignals& sg, Signal* self_sg, int rank) {
...
@@ -263,7 +264,7 @@ DINLINE void start_sync(const RankSignals& sg, Signal* self_sg, int rank) {
}
}
template
<
int
ngpus
,
bool
final_sync
=
false
>
template
<
int
ngpus
,
bool
final_sync
=
false
>
DINLINE
void
end_sync
(
const
RankSignals
&
sg
,
Signal
*
self_sg
,
int
rank
)
{
DINLINE
void
barrier_at_end
(
const
RankSignals
&
sg
,
Signal
*
self_sg
,
int
rank
)
{
__syncthreads
();
__syncthreads
();
uint32_t
flag
=
self_sg
->
_flag
[
blockIdx
.
x
]
+
1
;
uint32_t
flag
=
self_sg
->
_flag
[
blockIdx
.
x
]
+
1
;
if
(
threadIdx
.
x
<
ngpus
)
{
if
(
threadIdx
.
x
<
ngpus
)
{
...
@@ -273,14 +274,13 @@ DINLINE void end_sync(const RankSignals& sg, Signal* self_sg, int rank) {
...
@@ -273,14 +274,13 @@ DINLINE void end_sync(const RankSignals& sg, Signal* self_sg, int rank) {
// flag,
// flag,
// final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
// final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
// __MEMORY_SCOPE_SYSTEM);
// __MEMORY_SCOPE_SYSTEM);
// // wait until we got true from all ranks
__atomic_store_n
(
&
sg
.
signals
[
threadIdx
.
x
]
->
end
[
blockIdx
.
x
][
rank
],
flag
,
final_sync
?
__ATOMIC_RELAXED
:
__ATOMIC_RELEASE
);
// wait until we got true from all ranks
// while (
// while (
// __scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
// __scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
// final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
// final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
// __MEMORY_SCOPE_DEVICE) < flag);
// __MEMORY_SCOPE_DEVICE) < flag);
__atomic_store_n
(
&
sg
.
signals
[
threadIdx
.
x
]
->
end
[
blockIdx
.
x
][
rank
],
flag
,
final_sync
?
__ATOMIC_RELAXED
:
__ATOMIC_RELEASE
);
// wait until we got true from all ranks
while
(
__atomic_load_n
(
&
self_sg
->
end
[
blockIdx
.
x
][
threadIdx
.
x
],
while
(
__atomic_load_n
(
&
self_sg
->
end
[
blockIdx
.
x
][
threadIdx
.
x
],
final_sync
?
__ATOMIC_RELAXED
:
__ATOMIC_ACQUIRE
)
<
final_sync
?
__ATOMIC_RELAXED
:
__ATOMIC_ACQUIRE
)
<
flag
);
flag
);
...
@@ -311,13 +311,13 @@ __global__ void __launch_bounds__(512, 1)
...
@@ -311,13 +311,13 @@ __global__ void __launch_bounds__(512, 1)
// note: we don't reorder the address so the accumulation order is the same
// note: we don't reorder the address so the accumulation order is the same
// for all ranks, ensuring bitwise identical results
// for all ranks, ensuring bitwise identical results
auto
dp
=
*
_dp
;
auto
dp
=
*
_dp
;
start
_sync
<
ngpus
>
(
sg
,
self_sg
,
rank
);
barrier_at_
start
<
ngpus
>
(
sg
,
self_sg
,
rank
);
// do the actual reduction
// do the actual reduction
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
((
P
*
)
result
)[
idx
]
=
packed_reduce
<
P
,
ngpus
,
A
>
((
const
P
**
)
&
dp
.
ptrs
[
0
],
idx
);
((
P
*
)
result
)[
idx
]
=
packed_reduce
<
P
,
ngpus
,
A
>
((
const
P
**
)
&
dp
.
ptrs
[
0
],
idx
);
}
}
end_sync
<
ngpus
,
true
>
(
sg
,
self_sg
,
rank
);
barrier_at_end
<
ngpus
,
true
>
(
sg
,
self_sg
,
rank
);
}
}
template
<
typename
P
>
template
<
typename
P
>
...
@@ -346,19 +346,20 @@ __global__ void __launch_bounds__(512, 1)
...
@@ -346,19 +346,20 @@ __global__ void __launch_bounds__(512, 1)
tmps
[
i
]
=
get_tmp_buf
<
P
>
(
sg
.
signals
[
target
]);
tmps
[
i
]
=
get_tmp_buf
<
P
>
(
sg
.
signals
[
target
]);
}
}
auto
tmp_out
=
tmps
[
0
];
auto
tmp_out
=
tmps
[
0
];
start
_sync
<
ngpus
>
(
sg
,
self_sg
,
rank
);
barrier_at_
start
<
ngpus
>
(
sg
,
self_sg
,
rank
);
// stage 1: reduce scatter
// stage 1: reduce scatter
for
(
int
idx
=
start
+
tid
;
idx
<
end
;
idx
+=
stride
)
{
for
(
int
idx
=
start
+
tid
;
idx
<
end
;
idx
+=
stride
)
{
tmp_out
[
idx
-
start
]
=
packed_reduce
<
P
,
ngpus
,
A
>
(
ptrs
,
idx
);
tmp_out
[
idx
-
start
]
=
packed_reduce
<
P
,
ngpus
,
A
>
(
ptrs
,
idx
);
}
}
end_sync
<
ngpus
>
(
sg
,
self_sg
,
rank
);
barrier_at_end
<
ngpus
>
(
sg
,
self_sg
,
rank
);
// stage 2: allgather. Note: it's important to match the tid between
// stage 2: allgather. Note: it's important to match the tid between
// the two stages, because visibility across devices is only guaranteed
// the two stages, because visibility across devices is only guaranteed
// between threads that have the same tid. If thread i computes the sum of
// between threads that have the same tid. If thread i computes the sum of
// start + i in the first stage, then thread i also gathers start + i from all
// start + i in the first stage, then thread i also gathers start + i from
// ranks.
// all ranks.
for
(
int
idx
=
tid
;
idx
<
largest_part
;
idx
+=
stride
)
{
for
(
int
idx
=
tid
;
idx
<
largest_part
;
idx
+=
stride
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
ngpus
;
i
++
)
{
for
(
int
i
=
0
;
i
<
ngpus
;
i
++
)
{
...
@@ -379,7 +380,8 @@ class CustomAllreduce {
...
@@ -379,7 +380,8 @@ class CustomAllreduce {
public:
public:
int
rank_
;
int
rank_
;
int
world_size_
;
int
world_size_
;
bool
full_nvlink_
;
// Full NVLink or xGMI connection between GPUs.
bool
fully_connected_
;
RankSignals
sg_
;
RankSignals
sg_
;
// Stores an map from a pointer to its peer pointers from all ranks.
// Stores an map from a pointer to its peer pointers from all ranks.
...
@@ -388,12 +390,12 @@ class CustomAllreduce {
...
@@ -388,12 +390,12 @@ class CustomAllreduce {
// Stores rank data from all ranks. This is mainly for cuda graph purposes.
// Stores rank data from all ranks. This is mainly for cuda graph purposes.
// For cuda graph to work, all kernel arguments must be fixed during graph
// For cuda graph to work, all kernel arguments must be fixed during graph
// capture time. However, the peer pointers are not known during graph
capture
// capture time. However, the peer pointers are not known during graph
// time. Therefore, during capture, we increment the rank data
pointer and use
//
capture
time. Therefore, during capture, we increment the rank data
// that as the argument to the kernel. The kernel arguments
are stored in
//
pointer and use
that as the argument to the kernel. The kernel arguments
// graph_unreg_buffers_. The actual peer pointers will be
filled in at the
//
are stored in
graph_unreg_buffers_. The actual peer pointers will be
// memory pointed to by the pointers in
graph_unreg_buffers_ when
//
filled in at the
memory pointed to by the pointers in
// the IPC handles are exchanged between ranks.
//
graph_unreg_buffers_ when
the IPC handles are exchanged between ranks.
//
//
// The overall process looks like this:
// The overall process looks like this:
// 1. Graph capture.
// 1. Graph capture.
...
@@ -411,17 +413,18 @@ class CustomAllreduce {
...
@@ -411,17 +413,18 @@ class CustomAllreduce {
* Signals are an array of ipc-enabled buffers from all ranks.
* Signals are an array of ipc-enabled buffers from all ranks.
* For each of the buffer, the layout is as follows:
* For each of the buffer, the layout is as follows:
* | -- sizeof(Signal) -- | ------ a few MB ----- |
* | -- sizeof(Signal) -- | ------ a few MB ----- |
* The first section is for allreduce synchronization, and the second section
* The first section is for allreduce synchronization, and the second
* is for storing the intermediate results required by some allreduce algos.
* section is for storing the intermediate results required by some
* allreduce algos.
*
*
* Note: this class does not own any device memory. Any required buffers
* Note: this class does not own any device memory. Any required buffers
* are passed in from the constructor.
* are passed in from the constructor.
*/
*/
CustomAllreduce
(
Signal
**
signals
,
void
*
rank_data
,
size_t
rank_data_sz
,
CustomAllreduce
(
Signal
**
signals
,
void
*
rank_data
,
size_t
rank_data_sz
,
int
rank
,
int
world_size
,
bool
full
_nvlink
=
true
)
int
rank
,
int
world_size
,
bool
full
y_connected
=
true
)
:
rank_
(
rank
),
:
rank_
(
rank
),
world_size_
(
world_size
),
world_size_
(
world_size
),
full
_nvlink_
(
full_nvlink
),
full
y_connected_
(
fully_connected
),
self_sg_
(
signals
[
rank
]),
self_sg_
(
signals
[
rank
]),
d_rank_data_base_
(
reinterpret_cast
<
RankData
*>
(
rank_data
)),
d_rank_data_base_
(
reinterpret_cast
<
RankData
*>
(
rank_data
)),
d_rank_data_end_
(
d_rank_data_base_
+
rank_data_sz
/
sizeof
(
RankData
))
{
d_rank_data_end_
(
d_rank_data_base_
+
rank_data_sz
/
sizeof
(
RankData
))
{
...
@@ -487,11 +490,11 @@ class CustomAllreduce {
...
@@ -487,11 +490,11 @@ class CustomAllreduce {
// Note: when registering graph buffers, we intentionally choose to not
// Note: when registering graph buffers, we intentionally choose to not
// deduplicate the addresses. That means if the allocator reuses some
// deduplicate the addresses. That means if the allocator reuses some
// addresses, they will be registered again. This is to account for the
remote
// addresses, they will be registered again. This is to account for the
// possibility of different allocation patterns between ranks. For
example,
//
remote
possibility of different allocation patterns between ranks. For
// rank 1 may get the same input address for the second allreduce,
but rank 2
//
example,
rank 1 may get the same input address for the second allreduce,
// got a different address. IPC handles have internal reference
counting
//
but rank 2
got a different address. IPC handles have internal reference
// mechanism so overhead should be small.
//
counting
mechanism so overhead should be small.
void
register_graph_buffers
(
void
register_graph_buffers
(
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
)
{
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
)
{
...
@@ -522,11 +525,11 @@ class CustomAllreduce {
...
@@ -522,11 +525,11 @@ class CustomAllreduce {
/**
/**
* Performs allreduce, assuming input has already been registered.
* Performs allreduce, assuming input has already been registered.
*
*
* Block and grid default configs are results after careful grid search.
Using
* Block and grid default configs are results after careful grid search.
* 36 blocks give the best or close to the best runtime on the devices
I
*
Using
36 blocks give the best or close to the best runtime on the devices
* tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also
only
*
I
tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also
* take a small amount of SMs. Not quite sure the underlying reason,
but my
*
only
take a small amount of SMs. Not quite sure the underlying reason,
* guess is that too many SMs will cause contention on NVLink bus.
*
but my
guess is that too many SMs will cause contention on NVLink bus.
*/
*/
template
<
typename
T
>
template
<
typename
T
>
void
allreduce
(
cudaStream_t
stream
,
T
*
input
,
T
*
output
,
int
size
,
void
allreduce
(
cudaStream_t
stream
,
T
*
input
,
T
*
output
,
int
size
,
...
@@ -568,7 +571,7 @@ class CustomAllreduce {
...
@@ -568,7 +571,7 @@ class CustomAllreduce {
case ngpus: { \
case ngpus: { \
if (world_size_ == 2) { \
if (world_size_ == 2) { \
KL(ngpus, cross_device_reduce_1stage); \
KL(ngpus, cross_device_reduce_1stage); \
} else if (full
_nvlink_) {
\
} else if (full
y_connected_) {
\
if ((world_size_ <= 4 && bytes < 512 * 1024) || \
if ((world_size_ <= 4 && bytes < 512 * 1024) || \
(world_size_ <= 8 && bytes < 256 * 1024)) { \
(world_size_ <= 8 && bytes < 256 * 1024)) { \
KL(ngpus, cross_device_reduce_1stage); \
KL(ngpus, cross_device_reduce_1stage); \
...
@@ -601,10 +604,11 @@ class CustomAllreduce {
...
@@ -601,10 +604,11 @@ class CustomAllreduce {
}
}
}
}
};
};
/**
/**
* To inspect PTX/SASS, copy paste this header file to compiler explorer and
add
* To inspect PTX/SASS, copy paste this header file to compiler explorer and
a template instantiation:
add
a template instantiation:
* template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
* template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
half *, int, int, int);
half *, int, int, int);
*/
*/
}
// namespace vllm
}
// namespace vllm
\ No newline at end of file
csrc/custom_all_reduce_test.cu
View file @
fcfc474d
/**
/**
* This is a standalone test for custom allreduce.
* This is a standalone test for custom allreduce.
* To compile, make sure you have MPI and NCCL installed in your system.
* To compile, make sure you have MPI and NCCL installed in your system.
* export MPI_HOME=
xxx
* export MPI_HOME=
XXX
* nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
* nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
* custom_all_reduce_test -lnccl -I${MPI_HOME} -lmpi
* custom_all_reduce_test -lnccl -I${MPI_HOME}
/include
-lmpi
*
*
* Warning: this C++ test is not designed to be very readable and was used
* Warning: this C++ test is not designed to be very readable and was used
* during the rapid prototyping process.
* during the rapid prototyping process.
...
@@ -20,6 +20,7 @@
...
@@ -20,6 +20,7 @@
#include <vector>
#include <vector>
#include "cuda_profiler_api.h"
#include "cuda_profiler_api.h"
#include "custom_all_reduce.cuh"
#include "mpi.h"
#include "mpi.h"
#ifdef USE_ROCM
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include <hip/hip_bf16.h>
...
@@ -50,8 +51,8 @@ typedef __hip_bfloat16 nv_bfloat16;
...
@@ -50,8 +51,8 @@ typedef __hip_bfloat16 nv_bfloat16;
} \
} \
} while (0)
} while (0)
__global__
void
dummy_kernel
()
{
#ifdef USE_ROCM
#ifdef USE_ROCM
__global__
void
dummy_kernel
()
{
for
(
int
i
=
0
;
i
<
100
;
i
++
)
{
for
(
int
i
=
0
;
i
<
100
;
i
++
)
{
uint64_t
start
=
wall_clock64
();
uint64_t
start
=
wall_clock64
();
uint64_t
cycles_elapsed
;
uint64_t
cycles_elapsed
;
...
@@ -59,10 +60,20 @@ __global__ void dummy_kernel() {
...
@@ -59,10 +60,20 @@ __global__ void dummy_kernel() {
cycles_elapsed
=
wall_clock64
()
-
start
;
cycles_elapsed
=
wall_clock64
()
-
start
;
}
while
(
cycles_elapsed
<
100
);
}
while
(
cycles_elapsed
<
100
);
}
}
for
(
int
i
=
0
;
i
<
100
;
i
++
)
__nanosleep
(
1000000
);
// 100ms
}
#else
#else
__global__
void
dummy_kernel
()
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
for
(
int
i
=
0
;
i
<
100
;
i
++
)
__nanosleep
(
1000000
);
// 100ms
for
(
int
i
=
0
;
i
<
100
;
i
++
)
__nanosleep
(
1000000
);
// 100ms
#endif
#else
for
(
int
i
=
0
;
i
<
100
;
i
++
)
{
long
long
int
start
=
clock64
();
while
(
clock64
()
-
start
<
150000000
);
// approximately 98.4ms on P40
}
#endif
}
}
#endif
template
<
typename
T
>
template
<
typename
T
>
__global__
void
set_data
(
T
*
data
,
int
size
,
int
myRank
)
{
__global__
void
set_data
(
T
*
data
,
int
size
,
int
myRank
)
{
...
@@ -151,24 +162,26 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
...
@@ -151,24 +162,26 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
void
*
rank_data
;
void
*
rank_data
;
size_t
rank_data_sz
=
16
*
1024
*
1024
;
size_t
rank_data_sz
=
16
*
1024
*
1024
;
CUDACHECK
(
cudaMalloc
(
&
rank_data
,
rank_data_sz
));
CUDACHECK
(
cudaMalloc
(
&
rank_data
,
rank_data_sz
));
std
::
vector
<
int64_t
>
offsets
(
nRanks
,
0
);
vllm
::
Signal
*
ipc_ptrs
[
8
];
vllm
::
CustomAllreduce
fa
(
buffer
,
rank_data
,
rank_data_sz
,
data_handles
,
for
(
int
i
=
0
;
i
<
nRanks
;
i
++
)
{
offsets
,
myRank
);
if
(
i
==
myRank
)
ipc_ptrs
[
i
]
=
buffer
;
else
CUDACHECK
(
cudaIpcOpenMemHandle
((
void
**
)
&
ipc_ptrs
[
i
],
data_handles
[
i
],
cudaIpcMemLazyEnablePeerAccess
));
}
vllm
::
CustomAllreduce
fa
(
ipc_ptrs
,
rank_data
,
rank_data_sz
,
myRank
,
nRanks
);
auto
*
self_data
=
auto
*
self_data
=
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
char
*>
(
buffer
)
+
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
char
*>
(
buffer
)
+
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
));
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
));
// hack buffer registration
// hack buffer registration
{
{
std
::
vector
<
std
::
string
>
handles
;
void
*
data
[
8
];
handles
.
reserve
(
nRanks
);
for
(
int
i
=
0
;
i
<
nRanks
;
i
++
)
{
for
(
int
i
=
0
;
i
<
nRanks
;
i
++
)
{
char
*
begin
=
(
char
*
)
&
data_handles
[
i
];
data
[
i
]
=
char
*
end
=
(
char
*
)
&
data_handles
[
i
+
1
];
((
char
*
)
ipc_ptrs
[
i
])
+
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
);
handles
.
emplace_back
(
begin
,
end
);
}
}
std
::
vector
<
int64_t
>
offsets
(
nRanks
,
fa
.
register_buffer
(
data
);
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
));
fa
.
register_buffer
(
handles
,
offsets
,
self_data
);
}
}
double
*
ground_truth
;
double
*
ground_truth
;
...
@@ -280,14 +293,14 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
...
@@ -280,14 +293,14 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
if
(
diff
>=
4e-2
)
{
if
(
diff
>=
4e-2
)
{
printf
(
printf
(
"Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f
\n
"
,
"Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f
\n
"
,
myRank
,
j
,
nccl_result
[
j
],
my_result
[
j
],
ground_truth
[
j
]);
myRank
,
j
,
nccl_result
[
j
],
my_result
[
j
],
ground_truth
[
j
]);
break
;
break
;
}
}
}
}
}
}
if
(
myRank
==
0
)
if
(
myRank
==
0
)
printf
(
"Test passed: nGPUs:%d, sz (kb): %d, %d, %d
\n
"
,
nRanks
,
printf
(
"Test passed: nGPUs:%d, sz (kb): %d, %d, %d
\n
"
,
nRanks
,
data_size
*
sizeof
(
T
)
/
1024
,
threads
,
block_limit
);
data_size
*
sizeof
(
T
)
/
1024
,
threads
,
block_limit
);
// long double nccl_diffs = 0.0;
// long double nccl_diffs = 0.0;
// long double my_diffs = 0.0;
// long double my_diffs = 0.0;
// for (int j = 0; j < data_size; j++) {
// for (int j = 0; j < data_size; j++) {
...
@@ -320,27 +333,29 @@ int main(int argc, char** argv) {
...
@@ -320,27 +333,29 @@ int main(int argc, char** argv) {
ncclComm_t
comm
;
ncclComm_t
comm
;
if
(
myRank
==
0
)
ncclGetUniqueId
(
&
id
);
if
(
myRank
==
0
)
ncclGetUniqueId
(
&
id
);
MPICHECK
(
MPI_Bcast
(
static_cast
<
void
*>
(
&
id
),
sizeof
(
id
),
MPI_BYTE
,
0
,
MPICHECK
(
MPI_Bcast
(
static_cast
<
void
*>
(
&
id
),
sizeof
(
id
),
MPI_BYTE
,
0
,
MPI_COMM_WORLD
));
MPI_COMM_WORLD
));
NCCLCHECK
(
ncclCommInitRank
(
&
comm
,
nRanks
,
id
,
myRank
));
NCCLCHECK
(
ncclCommInitRank
(
&
comm
,
nRanks
,
id
,
myRank
));
bool
performance_test
=
true
;
bool
performance_test
=
true
;
cudaProfilerStart
();
cudaProfilerStart
();
// for (int threads : {256, 512}) {
// Uncomment to scan through different block size configs.
// for (int block_limit = 16; block_limit < 112; block_limit += 4) {
// for (int threads : {256, 512, 1024}) {
// run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
// for (int block_limit = 16; block_limit < 112; block_limit += 4) {
// }
// run<half>(myRank, nRanks, comm, threads, block_limit, 1024 * 1024,
// }
// performance_test);
// }
// }
#ifdef USE_ROCM
#ifdef USE_ROCM
for
(
int
sz
=
512
;
sz
<=
(
8
<<
20
);
sz
*=
2
)
{
const
int
block_limit
=
16
;
run
<
half
>
(
myRank
,
nRanks
,
comm
,
512
,
16
,
sz
+
8
*
47
,
performance_test
);
}
#else
#else
const
int
block_limit
=
36
;
#endif
// Scan through different sizes to test performance.
for
(
int
sz
=
512
;
sz
<=
(
8
<<
20
);
sz
*=
2
)
{
for
(
int
sz
=
512
;
sz
<=
(
8
<<
20
);
sz
*=
2
)
{
run
<
half
>
(
myRank
,
nRanks
,
comm
,
512
,
36
,
sz
+
8
*
47
,
performance_test
);
run
<
half
>
(
myRank
,
nRanks
,
comm
,
512
,
36
,
sz
+
8
*
47
,
performance_test
);
}
}
#endif
cudaProfilerStop
();
cudaProfilerStop
();
MPICHECK
(
MPI_Finalize
());
MPICHECK
(
MPI_Finalize
());
return
EXIT_SUCCESS
;
return
EXIT_SUCCESS
;
}
}
\ No newline at end of file
csrc/cutlass_extensions/common.hpp
View file @
fcfc474d
...
@@ -48,4 +48,14 @@ struct enable_sm90_or_later : Kernel {
...
@@ -48,4 +48,14 @@ struct enable_sm90_or_later : Kernel {
Kernel
::
operator
()(
std
::
forward
<
Args
>
(
args
)...);
Kernel
::
operator
()(
std
::
forward
<
Args
>
(
args
)...);
#endif
#endif
}
}
};
};
\ No newline at end of file
template
<
typename
Kernel
>
struct
enable_sm90_only
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
void
operator
()(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 900
Kernel
::
operator
()(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp
0 → 100644
View file @
fcfc474d
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
// from https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either row/column or scalar broadcasting
// where the tensor being loaded from is always passed in via a device pointer.
// This lets one compiled kernel handle all cases of per-tensor or
// per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graphs
// breaks when moving scales to the CPU.
//
#pragma once
// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off
#include "cutlass/cutlass.h"
#include "cutlass/arch/barrier.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
namespace
cutlass
::
epilogue
::
fusion
{
using
namespace
cute
;
using
namespace
detail
;
// Row vector broadcast
template
<
int
Stages
,
class
CtaTileShapeMNK
,
class
Element
,
class
StrideMNL
=
Stride
<
_0
,
_1
,
_0
>,
int
Alignment
=
128
/
sizeof_bits_v
<
Element
>
>
struct
Sm90RowOrScalarBroadcastArray
{
static_assert
(
Stages
==
0
,
"Row broadcast doesn't support smem usage"
);
static_assert
(
is_static_v
<
decltype
(
take
<
0
,
2
>
(
StrideMNL
{}))
>
);
// batch stride can be dynamic or static
static_assert
(
take
<
0
,
2
>
(
StrideMNL
{})
==
Stride
<
_0
,
_1
>
{});
struct
SharedStorage
{
array_aligned
<
Element
,
size
<
1
>
(
CtaTileShapeMNK
{})
>
smem
;
};
// This struct has been modified to have a bool indicating that ptr_row is a
// scalar that must be broadcast, instead of containing a scalar that is
// valid if ptr_row is null.
struct
Arguments
{
const
Element
*
const
*
ptr_row_array
=
nullptr
;
bool
row_broadcast
=
true
;
StrideMNL
dRow
=
{};
};
using
Params
=
Arguments
;
template
<
class
ProblemShape
>
static
constexpr
Params
to_underlying_arguments
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
,
void
*
workspace
)
{
return
args
;
}
template
<
class
ProblemShape
>
static
bool
can_implement
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
return
true
;
}
template
<
class
ProblemShape
>
static
size_t
get_workspace_size
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
return
0
;
}
template
<
class
ProblemShape
>
static
cutlass
::
Status
initialize_workspace
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
,
void
*
workspace
,
cudaStream_t
stream
,
CudaHostAdapter
*
cuda_adapter
=
nullptr
)
{
return
cutlass
::
Status
::
kSuccess
;
}
CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcastArray
()
{
}
CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcastArray
(
Params
const
&
params
,
SharedStorage
const
&
shared_storage
)
:
params
(
params
)
,
smem
(
const_cast
<
Element
*>
(
shared_storage
.
smem
.
data
()))
{
}
Params
params
;
Element
*
smem
=
nullptr
;
CUTLASS_DEVICE
bool
is_producer_load_needed
()
const
{
return
false
;
}
CUTLASS_DEVICE
bool
is_C_load_needed
()
const
{
return
false
;
}
CUTLASS_DEVICE
bool
is_zero
()
const
{
return
(
!
params
.
row_broadcast
&&
*
(
params
.
ptr_row_array
[
group
])
==
Element
(
0
));
}
template
<
class
...
Args
>
CUTLASS_DEVICE
auto
get_producer_load_callbacks
(
ProducerLoadArgs
<
Args
...
>
const
&
args
)
{
return
EmptyProducerLoadCallbacks
{};
}
template
<
class
GS_GTensor
,
class
GS_STensor
,
class
GS_CTensor
,
class
Tiled_G2S
,
class
SR_STensor
,
class
SR_RTensor
,
class
CTensor
,
class
ThrResidue
,
class
ThrNum
>
struct
ConsumerStoreCallbacks
:
EmptyConsumerStoreCallbacks
{
CUTLASS_DEVICE
ConsumerStoreCallbacks
(
GS_GTensor
tGS_gRow_
,
GS_STensor
tGS_sRow_
,
GS_CTensor
tGS_cRow_
,
Tiled_G2S
tiled_g2s_
,
SR_STensor
tSR_sRow_
,
SR_RTensor
tSR_rRow_
,
CTensor
tCcRow_
,
ThrResidue
residue_tCcRow_
,
ThrNum
thr_num_
,
int
group
,
Params
const
&
params_
)
:
tGS_gRow
(
tGS_gRow_
)
,
tGS_sRow
(
tGS_sRow_
)
,
tGS_cRow
(
tGS_cRow_
)
,
tiled_G2S
(
tiled_g2s_
)
,
tSR_sRow
(
tSR_sRow_
)
,
tSR_rRow
(
tSR_rRow_
)
,
tCcRow
(
tCcRow_
)
,
residue_tCcRow
(
residue_tCcRow_
)
,
group
(
group
)
,
params
(
params_
)
{}
GS_GTensor
tGS_gRow
;
// (CPY,CPY_M,CPY_N)
GS_STensor
tGS_sRow
;
// (CPY,CPY_M,CPY_N)
GS_CTensor
tGS_cRow
;
// (CPY,CPY_M,CPY_N)
Tiled_G2S
tiled_G2S
;
SR_STensor
tSR_sRow
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
SR_RTensor
tSR_rRow
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
CTensor
tCcRow
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
ThrResidue
residue_tCcRow
;
// (m, n)
ThrNum
thr_num
;
int
group
;
Params
const
&
params
;
CUTLASS_DEVICE
void
begin
()
{
if
(
!
params
.
row_broadcast
)
{
fill
(
tSR_rRow
,
*
(
params
.
ptr_row_array
[
group
]));
return
;
}
auto
synchronize
=
[
&
]
()
{
cutlass
::
arch
::
NamedBarrier
::
sync
(
thr_num
,
cutlass
::
arch
::
ReservedNamedBarriers
::
EpilogueBarrier
);
};
Tensor
tGS_gRow_flt
=
filter_zeros
(
tGS_gRow
);
Tensor
tGS_sRow_flt
=
filter_zeros
(
tGS_sRow
);
Tensor
tGS_cRow_flt
=
make_tensor
(
tGS_cRow
.
data
(),
make_layout
(
tGS_gRow_flt
.
shape
(),
tGS_cRow
.
stride
()));
for
(
int
i
=
0
;
i
<
size
(
tGS_gRow_flt
);
++
i
)
{
if
(
get
<
1
>
(
tGS_cRow_flt
(
i
))
>=
size
<
1
>
(
CtaTileShapeMNK
{}))
{
continue
;
// OOB of SMEM,
}
if
(
elem_less
(
tGS_cRow_flt
(
i
),
make_coord
(
get
<
0
>
(
residue_tCcRow
),
get
<
1
>
(
residue_tCcRow
))))
{
tGS_sRow_flt
(
i
)
=
tGS_gRow_flt
(
i
);
}
else
{
tGS_sRow_flt
(
i
)
=
Element
(
0
);
// Set to Zero when OOB so LDS could be issue without any preds.
}
}
synchronize
();
}
CUTLASS_DEVICE
void
begin_loop
(
int
epi_m
,
int
epi_n
)
{
if
(
epi_m
==
0
)
{
// Assumes M-major subtile loop
if
(
!
params
.
row_broadcast
)
return
;
// Do not issue LDS when row is scalar
Tensor
tSR_sRow_flt
=
filter_zeros
(
tSR_sRow
(
_
,
_
,
_
,
epi_m
,
epi_n
));
Tensor
tSR_rRow_flt
=
filter_zeros
(
tSR_rRow
);
copy
(
tSR_sRow_flt
,
tSR_rRow_flt
);
}
}
template
<
typename
ElementAccumulator
,
int
FragmentSize
>
CUTLASS_DEVICE
Array
<
Element
,
FragmentSize
>
visit
(
Array
<
ElementAccumulator
,
FragmentSize
>
const
&
frg_acc
,
int
epi_v
,
int
epi_m
,
int
epi_n
)
{
Array
<
Element
,
FragmentSize
>
frg_row
;
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
FragmentSize
;
++
i
)
{
frg_row
[
i
]
=
tSR_rRow
(
epi_v
*
FragmentSize
+
i
);
}
return
frg_row
;
}
};
template
<
bool
ReferenceSrc
,
// do register tensors reference the src or dst layout of the tiled copy
class
...
Args
>
CUTLASS_DEVICE
auto
get_consumer_store_callbacks
(
ConsumerStoreArgs
<
Args
...
>
const
&
args
)
{
auto
[
M
,
N
,
K
,
L
]
=
args
.
problem_shape_mnkl
;
auto
[
m
,
n
,
k
,
l
]
=
args
.
tile_coord_mnkl
;
using
ThreadCount
=
decltype
(
size
(
args
.
tiled_copy
));
Tensor
mRow
=
make_tensor
(
make_gmem_ptr
(
params
.
ptr_row_array
[
l
]),
make_shape
(
M
,
N
,
1
),
params
.
dRow
);
Tensor
gRow
=
local_tile
(
mRow
(
_
,
_
,
l
),
take
<
0
,
2
>
(
args
.
tile_shape_mnk
),
make_coord
(
m
,
n
));
// (CTA_M, CTA_N)
Tensor
sRow
=
make_tensor
(
make_smem_ptr
(
smem
),
make_shape
(
size
<
0
>
(
CtaTileShapeMNK
{}),
size
<
1
>
(
CtaTileShapeMNK
{})),
make_shape
(
_0
{},
_1
{}));
// (CTA_M, CTA_N)
//// G2S: Gmem to Smem
auto
tiled_g2s
=
make_tiled_copy
(
Copy_Atom
<
DefaultCopy
,
Element
>
{},
Layout
<
Shape
<
_1
,
ThreadCount
>
,
Stride
<
_0
,
_1
>>
{},
Layout
<
_1
>
{});
auto
thr_g2s
=
tiled_g2s
.
get_slice
(
args
.
thread_idx
);
Tensor
tGS_gRow
=
thr_g2s
.
partition_S
(
gRow
);
Tensor
tGS_sRow
=
thr_g2s
.
partition_D
(
sRow
);
//// G2S: Coord
auto
cRow
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
CtaTileShapeMNK
{}),
size
<
1
>
(
CtaTileShapeMNK
{})));
Tensor
tGS_cRow
=
thr_g2s
.
partition_S
(
cRow
);
//// S2R: Smem to Reg
Tensor
tSR_sRow
=
sm90_partition_for_epilogue
<
ReferenceSrc
>
(
sRow
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
Tensor
tSR_rRow
=
make_tensor_like
(
take
<
0
,
3
>
(
tSR_sRow
));
// (CPY,CPY_M,CPY_N)
return
ConsumerStoreCallbacks
<
decltype
(
tGS_gRow
),
decltype
(
tGS_sRow
),
decltype
(
tGS_cRow
),
decltype
(
tiled_g2s
),
decltype
(
tSR_sRow
),
decltype
(
tSR_rRow
),
decltype
(
args
.
tCcD
),
decltype
(
args
.
residue_cD
),
ThreadCount
>
(
tGS_gRow
,
tGS_sRow
,
tGS_cRow
,
tiled_g2s
,
tSR_sRow
,
tSR_rRow
,
args
.
tCcD
,
args
.
residue_cD
,
ThreadCount
{},
l
,
params
);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Column vector broadcast
template
<
int
Stages
,
class
CtaTileShapeMNK
,
class
Element
,
class
StrideMNL
=
Stride
<
_1
,
_0
,
_0
>,
int
Alignment
=
128
/
sizeof_bits_v
<
Element
>
>
struct
Sm90ColOrScalarBroadcastArray
{
static_assert
(
Stages
==
0
,
"Column broadcast doesn't support smem usage yet"
);
static_assert
(
Alignment
*
sizeof_bits_v
<
Element
>
%
128
==
0
,
"sub-16B alignment not supported yet"
);
static_assert
(
(
cute
::
is_same_v
<
StrideMNL
,
Stride
<
_1
,
_0
,
_0
>>
)
||
// col vector broadcast, e.g. per-row alpha/bias
(
cute
::
is_same_v
<
StrideMNL
,
Stride
<
_1
,
_0
,
int
>>
));
// batched col vector broadcast, e.g. batched per-row bias
// Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
struct
SharedStorage
{
};
// This struct has been modified to have a bool indicating that ptr_col is a
// scalar that must be broadcast, instead of containing a scalar that is
// valid if ptr_col is null.
struct
Arguments
{
const
Element
*
const
*
ptr_col_array
=
nullptr
;
bool
col_broadcast
=
true
;
StrideMNL
dCol
=
{};
};
using
Params
=
Arguments
;
template
<
class
ProblemShape
>
static
constexpr
Params
to_underlying_arguments
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
,
void
*
workspace
)
{
return
args
;
}
template
<
class
ProblemShape
>
static
bool
can_implement
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
return
true
;
}
template
<
class
ProblemShape
>
static
size_t
get_workspace_size
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
return
0
;
}
template
<
class
ProblemShape
>
static
cutlass
::
Status
initialize_workspace
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
,
void
*
workspace
,
cudaStream_t
stream
,
CudaHostAdapter
*
cuda_adapter
=
nullptr
)
{
return
cutlass
::
Status
::
kSuccess
;
}
CUTLASS_DEVICE
bool
is_producer_load_needed
()
const
{
return
false
;
}
CUTLASS_DEVICE
bool
is_C_load_needed
()
const
{
return
false
;
}
CUTLASS_DEVICE
bool
is_zero
()
const
{
return
(
!
params
.
col_broadcast
&&
*
(
params
.
ptr_col_array
[
group
])
==
Element
(
0
));
}
CUTLASS_HOST_DEVICE
Sm90ColOrScalarBroadcastArray
()
{
}
CUTLASS_HOST_DEVICE
Sm90ColOrScalarBroadcastArray
(
Params
const
&
params
,
SharedStorage
const
&
shared_storage
)
:
params
(
params
)
{
}
Params
params
;
template
<
class
...
Args
>
CUTLASS_DEVICE
auto
get_producer_load_callbacks
(
ProducerLoadArgs
<
Args
...
>
const
&
args
)
{
return
EmptyProducerLoadCallbacks
{};
}
template
<
class
GTensor
,
class
RTensor
,
class
CTensor
,
class
ProblemShape
>
struct
ConsumerStoreCallbacks
:
EmptyConsumerStoreCallbacks
{
CUTLASS_DEVICE
ConsumerStoreCallbacks
(
GTensor
&&
tCgCol
,
RTensor
&&
tCrCol
,
CTensor
&&
tCcCol
,
ProblemShape
problem_shape
,
int
group
,
Params
const
&
params
)
:
tCgCol
(
cute
::
forward
<
GTensor
>
(
tCgCol
)),
tCrCol
(
cute
::
forward
<
RTensor
>
(
tCrCol
)),
tCcCol
(
cute
::
forward
<
CTensor
>
(
tCcCol
)),
m
(
get
<
0
>
(
problem_shape
)),
group
(
group
),
params
(
params
)
{}
GTensor
tCgCol
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor
tCrCol
;
CTensor
tCcCol
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Params
const
&
params
;
int
m
;
int
group
;
CUTLASS_DEVICE
void
begin
()
{
Tensor
pred
=
make_tensor
<
bool
>
(
shape
(
tCgCol
));
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
size
(
pred
);
++
i
)
{
pred
(
i
)
=
get
<
0
>
(
tCcCol
(
i
))
<
m
;
}
if
(
!
params
.
col_broadcast
)
{
fill
(
tCrCol
,
*
(
params
.
ptr_col_array
[
group
]));
return
;
}
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
copy_if
(
pred
,
filter
(
tCgCol
),
filter
(
tCrCol
));
}
template
<
typename
ElementAccumulator
,
int
FragmentSize
>
CUTLASS_DEVICE
Array
<
Element
,
FragmentSize
>
visit
(
Array
<
ElementAccumulator
,
FragmentSize
>
const
&
frg_acc
,
int
epi_v
,
int
epi_m
,
int
epi_n
)
{
Array
<
Element
,
FragmentSize
>
frg_col
;
Tensor
tCrCol_mn
=
tCrCol
(
_
,
_
,
_
,
epi_m
,
epi_n
);
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
FragmentSize
;
++
i
)
{
frg_col
[
i
]
=
tCrCol_mn
(
epi_v
*
FragmentSize
+
i
);
}
return
frg_col
;
}
};
template
<
bool
ReferenceSrc
,
// do register tensors reference the src or dst layout of the tiled copy
class
...
Args
>
CUTLASS_DEVICE
auto
get_consumer_store_callbacks
(
ConsumerStoreArgs
<
Args
...
>
const
&
args
)
{
auto
[
M
,
N
,
K
,
L
]
=
args
.
problem_shape_mnkl
;
auto
[
m
,
n
,
k
,
l
]
=
args
.
tile_coord_mnkl
;
Tensor
mCol
=
make_tensor
(
make_gmem_ptr
(
params
.
ptr_col_array
[
l
]),
make_shape
(
M
,
N
,
1
),
params
.
dCol
);
Tensor
tCgCol
=
sm90_partition_for_epilogue
<
ReferenceSrc
>
(
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol
,
args
.
tile_shape_mnk
,
args
.
tile_coord_mnkl
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
Tensor
tCrCol
=
make_tensor_like
(
tCgCol
);
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
// Generate an identity tensor matching the shape of the global tensor and
// partition the same way, this will be used to generate the predicate
// tensor for loading
Tensor
cCol
=
make_identity_tensor
(
mCol
.
shape
());
Tensor
tCcCol
=
sm90_partition_for_epilogue
<
ReferenceSrc
>
(
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol
,
args
.
tile_shape_mnk
,
args
.
tile_coord_mnkl
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
return
ConsumerStoreCallbacks
(
cute
::
move
(
tCgCol
),
cute
::
move
(
tCrCol
),
cute
::
move
(
tCcCol
),
args
.
problem_shape_mnkl
,
l
,
params
);
}
};
}
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
View file @
fcfc474d
#pragma once
#pragma once
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
/*
/*
This file defines custom epilogues for fusing channel scales, token scales,
This file defines custom epilogues for fusing channel scales, token scales,
...
@@ -69,6 +70,16 @@ struct ScaledEpilogueBase {
...
@@ -69,6 +70,16 @@ struct ScaledEpilogueBase {
0
/*Stages*/
,
TileShape
,
T
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>
,
0
/*Stages*/
,
TileShape
,
T
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
template
<
typename
T
>
using
ColOrScalarLoadArray
=
cutlass
::
epilogue
::
fusion
::
Sm90ColOrScalarBroadcastArray
<
0
/*Stages*/
,
TileShape
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
template
<
typename
T
>
using
RowOrScalarLoadArray
=
cutlass
::
epilogue
::
fusion
::
Sm90RowOrScalarBroadcastArray
<
0
/*Stages*/
,
TileShape
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
// This utility function constructs the arguments for the load descriptors
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
// scalar cases.
...
@@ -96,6 +107,14 @@ struct ScaledEpilogueBase {
...
@@ -96,6 +107,14 @@ struct ScaledEpilogueBase {
std
::
is_same_v
<
Descriptor
,
RowLoad
<
T
,
true
>>
);
std
::
is_same_v
<
Descriptor
,
RowLoad
<
T
,
true
>>
);
return
Arguments
{
data_ptr
};
return
Arguments
{
data_ptr
};
}
}
template
<
typename
Descriptor
,
typename
T
>
static
auto
args_from_tensor
(
const
T
*
const
*
data_ptr
,
bool
do_broadcast
)
{
using
Arguments
=
typename
Descriptor
::
Arguments
;
static_assert
(
std
::
is_same_v
<
Descriptor
,
ColOrScalarLoadArray
<
T
>>
||
std
::
is_same_v
<
Descriptor
,
RowOrScalarLoadArray
<
T
>>
);
return
Arguments
{
data_ptr
,
do_broadcast
};
}
};
};
/*
/*
...
@@ -381,4 +400,51 @@ struct ScaledEpilogueBiasAzpToken
...
@@ -381,4 +400,51 @@ struct ScaledEpilogueBiasAzpToken
}
}
};
};
/*
This epilogue works like ScaledEpilogue, but ScaleA and ScaleB are pointers
to arrays containing different scales used in group gemm. The number of
pointers in ScaleA and the number of pointers in ScaleB are equal to the
group size.
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
struct
ScaledEpilogueArray
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoadArray
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoadArray
<
float
>;
using
Compute0
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute0
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiplies
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
Compute1
,
ScaleA
,
EVTCompute0
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
using
ScaleAArray
=
typename
SUPER
::
template
ColOrScalarLoadArray
<
float
>;
using
ScaleBArray
=
typename
SUPER
::
template
RowOrScalarLoadArray
<
float
>;
static
ArgumentType
prepare_args
(
float
const
*
const
*
a_scales_ptr
,
float
const
*
const
*
b_scales_ptr
,
bool
a_col_broadcast
,
bool
b_row_broadcast
)
{
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleAArray
,
float
>(
a_scales_ptr
,
a_col_broadcast
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleBArray
,
float
>(
b_scales_ptr
,
b_row_broadcast
);
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
,
{},
{}};
return
ArgumentType
{
a_args
,
evt0_args
,
{}};
}
};
};
// namespace vllm::c3x
};
// namespace vllm::c3x
csrc/ops.h
View file @
fcfc474d
...
@@ -260,6 +260,8 @@ void advance_step_flashinfer(
...
@@ -260,6 +260,8 @@ void advance_step_flashinfer(
torch
::
Tensor
&
paged_kv_indices
,
torch
::
Tensor
&
paged_kv_indptr
,
torch
::
Tensor
&
paged_kv_indices
,
torch
::
Tensor
&
paged_kv_indptr
,
torch
::
Tensor
&
paged_kv_last_page_len
,
torch
::
Tensor
&
block_table_bounds
);
torch
::
Tensor
&
paged_kv_last_page_len
,
torch
::
Tensor
&
block_table_bounds
);
torch
::
Tensor
get_cuda_view_from_cpu_tensor
(
torch
::
Tensor
&
cpu_tensor
);
#ifndef USE_ROCM
#ifndef USE_ROCM
torch
::
Tensor
aqlm_gemm
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codes
,
torch
::
Tensor
aqlm_gemm
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
codebooks
,
...
@@ -284,7 +286,8 @@ torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
...
@@ -284,7 +286,8 @@ torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
#endif
#endif
torch
::
Tensor
ggml_dequantize
(
torch
::
Tensor
W
,
int64_t
type
,
int64_t
m
,
torch
::
Tensor
ggml_dequantize
(
torch
::
Tensor
W
,
int64_t
type
,
int64_t
m
,
int64_t
n
);
int64_t
n
,
std
::
optional
<
at
::
ScalarType
>
const
&
dtype
);
torch
::
Tensor
ggml_mul_mat_vec_a8
(
torch
::
Tensor
W
,
torch
::
Tensor
X
,
torch
::
Tensor
ggml_mul_mat_vec_a8
(
torch
::
Tensor
W
,
torch
::
Tensor
X
,
int64_t
type
,
int64_t
row
);
int64_t
type
,
int64_t
row
);
...
@@ -305,6 +308,7 @@ int64_t ggml_moe_get_block_size(int64_t type);
...
@@ -305,6 +308,7 @@ int64_t ggml_moe_get_block_size(int64_t type);
bool
cutlass_scaled_mm_supports_fp4
(
int64_t
cuda_device_capability
);
bool
cutlass_scaled_mm_supports_fp4
(
int64_t
cuda_device_capability
);
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
);
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
);
bool
cutlass_scaled_mm_supports_block_fp8
(
int64_t
cuda_device_capability
);
bool
cutlass_scaled_mm_supports_block_fp8
(
int64_t
cuda_device_capability
);
bool
cutlass_group_gemm_supported
(
int64_t
cuda_device_capability
);
void
cutlass_scaled_fp4_mm
(
torch
::
Tensor
&
D
,
torch
::
Tensor
const
&
A
,
void
cutlass_scaled_fp4_mm
(
torch
::
Tensor
&
D
,
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
B
,
torch
::
Tensor
const
&
A_sf
,
torch
::
Tensor
const
&
B
,
torch
::
Tensor
const
&
A_sf
,
...
@@ -316,6 +320,19 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
...
@@ -316,6 +320,19 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_moe_mm
(
torch
::
Tensor
&
out_tensors
,
torch
::
Tensor
const
&
a_tensors
,
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
);
void
get_cutlass_moe_mm_data
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
);
void
cutlass_scaled_mm_azp
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_azp
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
...
@@ -394,7 +411,8 @@ void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
...
@@ -394,7 +411,8 @@ void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
using
fptr_t
=
int64_t
;
using
fptr_t
=
int64_t
;
fptr_t
init_custom_ar
(
const
std
::
vector
<
int64_t
>&
fake_ipc_ptrs
,
fptr_t
init_custom_ar
(
const
std
::
vector
<
int64_t
>&
fake_ipc_ptrs
,
torch
::
Tensor
&
rank_data
,
int64_t
rank
,
bool
full_nvlink
);
torch
::
Tensor
&
rank_data
,
int64_t
rank
,
bool
fully_connected
);
void
all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
void
all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
fptr_t
reg_buffer
,
int64_t
reg_buffer_sz_bytes
);
fptr_t
reg_buffer
,
int64_t
reg_buffer_sz_bytes
);
void
dispose
(
fptr_t
_fa
);
void
dispose
(
fptr_t
_fa
);
...
@@ -405,9 +423,7 @@ get_graph_buffer_ipc_meta(fptr_t _fa);
...
@@ -405,9 +423,7 @@ get_graph_buffer_ipc_meta(fptr_t _fa);
void
register_graph_buffers
(
fptr_t
_fa
,
void
register_graph_buffers
(
fptr_t
_fa
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
);
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
);
std
::
tuple
<
int64_t
,
torch
::
Tensor
>
allocate_shared_buffer_and_handle
(
std
::
tuple
<
int64_t
,
torch
::
Tensor
>
allocate_shared_buffer_and_handle
(
int64_t
size
);
int64_t
size
);
int64_t
open_mem_handle
(
torch
::
Tensor
&
mem_handle
);
int64_t
open_mem_handle
(
torch
::
Tensor
&
mem_handle
);
void
free_shared_buffer
(
int64_t
buffer
);
void
free_shared_buffer
(
int64_t
buffer
);
csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh
0 → 100644
View file @
fcfc474d
#pragma once
#include <cuda.h>
#include <torch/all.h>
#include <c10/cuda/CUDAStream.h>
#include "core/scalar_type.hpp"
#include "cutlass/bfloat16.h"
#include "cutlass/float8.h"
template
<
typename
ElementAB
,
typename
ElementC
,
typename
ElementAccumulator
>
__global__
void
get_group_gemm_starts
(
int32_t
*
expert_offsets
,
ElementAB
**
a_offsets
,
ElementAB
**
b_offsets
,
ElementC
**
out_offsets
,
ElementAccumulator
**
a_scales_offsets
,
ElementAccumulator
**
b_scales_offsets
,
ElementAB
*
a_base_as_int
,
ElementAB
*
b_base_as_int
,
ElementC
*
out_base_as_int
,
ElementAccumulator
*
a_scales_base_as_int
,
ElementAccumulator
*
b_scales_base_as_int
,
int64_t
n
,
int64_t
k
,
bool
per_act_token
,
bool
per_out_ch
)
{
int
expert_id
=
threadIdx
.
x
;
int64_t
expert_offset
=
expert_offsets
[
expert_id
];
a_offsets
[
expert_id
]
=
a_base_as_int
+
expert_offset
*
k
;
b_offsets
[
expert_id
]
=
b_base_as_int
+
expert_id
*
k
*
n
;
out_offsets
[
expert_id
]
=
out_base_as_int
+
expert_offset
*
n
;
a_scales_offsets
[
expert_id
]
=
a_scales_base_as_int
+
(
per_act_token
?
expert_offset
:
0
);
b_scales_offsets
[
expert_id
]
=
b_scales_base_as_int
+
(
per_out_ch
?
n
*
expert_id
:
expert_id
);
}
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float> \
<<<1, num_experts, 0, stream>>>( \
static_cast<int32_t*>(expert_offsets.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()), \
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
static_cast<float**>(a_scales_ptrs.data_ptr()), \
static_cast<float**>(b_scales_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()), \
static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()), \
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
static_cast<float*>(a_scales.data_ptr()), \
static_cast<float*>(b_scales.data_ptr()), out_tensors.size(1), \
a_tensors.size(1), per_act_token, per_out_ch); \
}
namespace
{
void
run_get_group_gemm_starts
(
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
&
a_ptrs
,
torch
::
Tensor
&
b_ptrs
,
torch
::
Tensor
&
out_ptrs
,
torch
::
Tensor
&
a_scales_ptrs
,
torch
::
Tensor
&
b_scales_ptrs
,
torch
::
Tensor
const
&
a_tensors
,
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
&
out_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
TORCH_CHECK
(
a_tensors
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b_tensors
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
int
num_experts
=
static_cast
<
int
>
(
expert_offsets
.
size
(
0
));
bool
per_act_token
=
a_scales
.
numel
()
!=
1
;
bool
per_out_ch
=
b_scales
.
numel
()
!=
num_experts
;
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a_tensors
.
device
().
index
());
if
(
false
)
{
}
__CALL_GET_STARTS_KERNEL
(
torch
::
kBFloat16
,
cutlass
::
bfloat16_t
)
__CALL_GET_STARTS_KERNEL
(
torch
::
kFloat16
,
half
)
else
{
TORCH_CHECK
(
false
,
"Invalid output type (must be float16 or bfloat16)"
);
}
}
}
// namespace
\ No newline at end of file
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu
0 → 100644
View file @
fcfc474d
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass/cutlass.h"
#include "grouped_mm_c3x.cuh"
using
namespace
cute
;
namespace
{
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_default
{
// M in (16, inf)
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
PtrArrayTmaWarpSpecializedPingpong
;
using
TileShape
=
cute
::
Shape
<
cute
::
_64
,
cute
::
_256
,
cute
::
_128
>
;
using
ClusterShape
=
cute
::
Shape
<
cute
::
_1
,
cute
::
_2
,
cute
::
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_group_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_M16
{
// M in [1, 16]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
PtrArrayTmaWarpSpecializedPingpong
;
using
TileShape
=
cute
::
Shape
<
cute
::
_64
,
cute
::
_64
,
cute
::
_128
>
;
using
ClusterShape
=
cute
::
Shape
<
cute
::
_1
,
cute
::
_4
,
cute
::
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_group_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_K8192
{
// K in [8192, inf)
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
PtrArrayTmaWarpSpecializedPingpong
;
using
TileShape
=
cute
::
Shape
<
cute
::
_128
,
cute
::
_128
,
cute
::
_128
>
;
using
ClusterShape
=
cute
::
Shape
<
cute
::
_1
,
cute
::
_8
,
cute
::
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_group_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_N8192
{
// N in [8192, inf)
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
PtrArrayTmaWarpSpecializedPingpong
;
using
TileShape
=
cute
::
Shape
<
cute
::
_64
,
cute
::
_128
,
cute
::
_256
>
;
using
ClusterShape
=
cute
::
Shape
<
cute
::
_1
,
cute
::
_8
,
cute
::
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_group_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
>
void
run_cutlass_moe_mm_sm90
(
torch
::
Tensor
&
out_tensors
,
torch
::
Tensor
const
&
a_tensors
,
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
)
{
TORCH_CHECK
(
a_tensors
.
size
(
0
)
>
0
,
"No input A tensors provided."
);
TORCH_CHECK
(
b_tensors
.
size
(
0
)
>
0
,
"No input B tensors provided."
);
TORCH_CHECK
(
out_tensors
.
size
(
0
)
>
0
,
"No output tensors provided."
);
TORCH_CHECK
(
a_tensors
.
dtype
()
==
torch
::
kFloat8_e4m3fn
,
"A tensors must be of type float8_e4m3fn."
);
TORCH_CHECK
(
b_tensors
.
dtype
()
==
torch
::
kFloat8_e4m3fn
,
"B tensors must be of type float8_e4m3fn."
);
TORCH_CHECK
(
a_tensors
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b_tensors
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
Cutlass3xGemmN8192
=
typename
sm90_fp8_config_N8192
<
InType
,
OutType
,
vllm
::
c3x
::
ScaledEpilogueArray
>::
Cutlass3xGemm
;
using
Cutlass3xGemmK8192
=
typename
sm90_fp8_config_K8192
<
InType
,
OutType
,
vllm
::
c3x
::
ScaledEpilogueArray
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM16
=
typename
sm90_fp8_config_M16
<
InType
,
OutType
,
vllm
::
c3x
::
ScaledEpilogueArray
>::
Cutlass3xGemm
;
using
Cutlass3xGemmDefault
=
typename
sm90_fp8_config_default
<
InType
,
OutType
,
vllm
::
c3x
::
ScaledEpilogueArray
>::
Cutlass3xGemm
;
uint32_t
const
m
=
a_tensors
.
size
(
0
);
uint32_t
const
n
=
out_tensors
.
size
(
1
);
uint32_t
const
k
=
a_tensors
.
size
(
1
);
if
(
n
>=
8192
)
{
cutlass_group_gemm_caller
<
Cutlass3xGemmN8192
>
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
);
}
else
if
(
k
>=
8192
)
{
cutlass_group_gemm_caller
<
Cutlass3xGemmK8192
>
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
);
}
else
if
(
m
<=
16
)
{
cutlass_group_gemm_caller
<
Cutlass3xGemmM16
>
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
);
}
else
{
cutlass_group_gemm_caller
<
Cutlass3xGemmDefault
>
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
);
}
}
void
dispatch_moe_mm_sm90
(
torch
::
Tensor
&
out_tensors
,
torch
::
Tensor
const
&
a_tensors
,
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
)
{
if
(
out_tensors
.
dtype
()
==
torch
::
kBFloat16
)
{
run_cutlass_moe_mm_sm90
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
>
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
);
}
else
{
run_cutlass_moe_mm_sm90
<
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
>
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
);
}
}
}
// namespace
void
cutlass_moe_mm_sm90
(
torch
::
Tensor
&
out_tensors
,
torch
::
Tensor
const
&
a_tensors
,
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
)
{
dispatch_moe_mm_sm90
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
);
}
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh
0 → 100644
View file @
fcfc474d
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "cutlass_extensions/common.hpp"
#include "get_group_starts.cuh"
using
namespace
cute
;
namespace
{
using
ProblemShape
=
cutlass
::
gemm
::
GroupProblemShape
<
cute
::
Shape
<
int
,
int
,
int
>>
;
using
ElementAccumulator
=
float
;
using
ArchTag
=
cutlass
::
arch
::
Sm90
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
using
LayoutA
=
cutlass
::
layout
::
RowMajor
;
using
LayoutB
=
cutlass
::
layout
::
ColumnMajor
;
using
LayoutC
=
cutlass
::
layout
::
RowMajor
;
template
<
typename
ElementAB_
,
typename
ElementC_
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue_
,
typename
TileShape
,
typename
ClusterShape
,
typename
KernelSchedule
,
typename
EpilogueSchedule
>
struct
cutlass_3x_group_gemm
{
using
ElementAB
=
ElementAB_
;
using
ElementC
=
void
;
using
ElementD
=
ElementC_
;
using
ElementAccumulator
=
float
;
using
Epilogue
=
Epilogue_
<
ElementAccumulator
,
ElementD
,
TileShape
>
;
using
StrideC
=
cute
::
remove_pointer_t
<
cute
::
Stride
<
int64_t
,
cute
::
Int
<
1
>
,
cute
::
Int
<
0
>>>
;
static
constexpr
int
AlignmentAB
=
128
/
cutlass
::
sizeof_bits
<
ElementAB
>::
value
;
static
constexpr
int
AlignmentC
=
128
/
cutlass
::
sizeof_bits
<
ElementD
>::
value
;
using
EVTCompute
=
typename
Epilogue
::
EVTCompute
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
TileShape
,
ClusterShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementAccumulator
,
ElementAccumulator
,
ElementC
,
LayoutC
*
,
AlignmentC
,
ElementD
,
LayoutC
*
,
AlignmentC
,
EpilogueSchedule
,
EVTCompute
>::
CollectiveOp
;
static
constexpr
size_t
CEStorageSize
=
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
);
using
Stages
=
typename
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
CEStorageSize
)
>
;
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
ElementAB
,
LayoutA
*
,
AlignmentAB
,
ElementAB
,
LayoutB
*
,
AlignmentAB
,
ElementAccumulator
,
TileShape
,
ClusterShape
,
Stages
,
KernelSchedule
>::
CollectiveOp
;
using
KernelType
=
enable_sm90_only
<
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
ProblemShape
,
CollectiveMainloop
,
CollectiveEpilogue
>>
;
struct
GemmKernel
:
public
KernelType
{};
};
template
<
typename
Gemm
>
void
cutlass_group_gemm_caller
(
torch
::
Tensor
&
out_tensors
,
torch
::
Tensor
const
&
a_tensors
,
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
)
{
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
int
num_experts
=
static_cast
<
int
>
(
expert_offsets
.
size
(
0
));
int
k_size
=
a_tensors
.
size
(
1
);
int
n_size
=
out_tensors
.
size
(
1
);
bool
per_act_token
=
a_scales
.
numel
()
!=
1
;
bool
per_out_ch
=
b_scales
.
numel
()
!=
num_experts
;
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a_tensors
.
device
().
index
());
auto
options_int
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt64
).
device
(
a_tensors
.
device
());
torch
::
Tensor
a_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
b_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
out_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
a_scales_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
b_scales_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
run_get_group_gemm_starts
(
expert_offsets
,
a_ptrs
,
b_ptrs
,
out_ptrs
,
a_scales_ptrs
,
b_scales_ptrs
,
a_tensors
,
b_tensors
,
out_tensors
,
a_scales
,
b_scales
);
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
using
StrideA
=
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>
;
using
StrideB
=
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>
;
using
StrideC
=
typename
GemmKernel
::
InternalStrideC
;
ProblemShape
::
UnderlyingProblemShape
*
problem_sizes_as_shapes
=
static_cast
<
ProblemShape
::
UnderlyingProblemShape
*>
(
problem_sizes
.
data_ptr
());
ProblemShape
prob_shape
{
num_experts
,
problem_sizes_as_shapes
,
nullptr
};
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
static_cast
<
const
ElementAB
**>
(
a_ptrs
.
data_ptr
()),
static_cast
<
StrideA
*>
(
a_strides
.
data_ptr
()),
static_cast
<
const
ElementAB
**>
(
b_ptrs
.
data_ptr
()),
static_cast
<
StrideB
*>
(
b_strides
.
data_ptr
())};
// Currently, we are only able to do broadcast on either all or none a_scales
// and on either all or none b_scales
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
Gemm
::
Epilogue
::
prepare_args
(
static_cast
<
const
ElementAccumulator
**>
(
a_scales_ptrs
.
data_ptr
()),
static_cast
<
const
ElementAccumulator
**>
(
b_scales_ptrs
.
data_ptr
()),
per_act_token
,
per_out_ch
),
nullptr
,
static_cast
<
StrideC
*>
(
c_strides
.
data_ptr
()),
static_cast
<
ElementD
**>
(
out_ptrs
.
data_ptr
()),
static_cast
<
StrideC
*>
(
c_strides
.
data_ptr
())};
typename
GemmKernel
::
Arguments
args
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGrouped
,
prob_shape
,
mainloop_args
,
epilogue_args
};
using
GemmOp
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
GemmKernel
>
;
GemmOp
gemm_op
;
CUTLASS_CHECK
(
gemm_op
.
can_implement
(
args
));
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
a_tensors
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
cutlass
::
Status
status
=
gemm_op
.
run
(
args
,
workspace
.
data_ptr
(),
stream
);
CUTLASS_CHECK
(
status
);
}
}
// namespace
csrc/quantization/cutlass_w8a8/moe/moe_data.cu
0 → 100644
View file @
fcfc474d
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <iostream>
constexpr
uint64_t
THREADS_PER_EXPERT
=
512
;
__global__
void
compute_problem_sizes
(
const
int
*
__restrict__
topk_ids
,
int32_t
*
problem_sizes1
,
int32_t
*
problem_sizes2
,
int32_t
*
atomic_buffer
,
const
int
topk_length
,
const
int
n
,
const
int
k
)
{
int
expert_id
=
blockIdx
.
x
;
int
occurrences
=
0
;
for
(
int
i
=
threadIdx
.
x
;
i
<
topk_length
;
i
+=
THREADS_PER_EXPERT
)
{
occurrences
+=
(
topk_ids
[
i
]
==
expert_id
);
}
atomicAdd
(
&
atomic_buffer
[
expert_id
],
occurrences
);
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
int
final_occurrences
=
atomic_buffer
[
expert_id
];
problem_sizes1
[
expert_id
*
3
]
=
final_occurrences
;
problem_sizes1
[
expert_id
*
3
+
1
]
=
2
*
n
;
problem_sizes1
[
expert_id
*
3
+
2
]
=
k
;
problem_sizes2
[
expert_id
*
3
]
=
final_occurrences
;
problem_sizes2
[
expert_id
*
3
+
1
]
=
k
;
problem_sizes2
[
expert_id
*
3
+
2
]
=
n
;
}
}
__global__
void
compute_expert_offsets
(
const
int32_t
*
__restrict__
problem_sizes1
,
int32_t
*
expert_offsets
,
int32_t
*
atomic_buffer
,
const
int
num_experts
)
{
int32_t
tot_offset
=
0
;
expert_offsets
[
0
]
=
0
;
for
(
int
i
=
0
;
i
<
num_experts
;
++
i
)
{
atomic_buffer
[
i
]
=
tot_offset
;
tot_offset
+=
problem_sizes1
[
i
*
3
];
expert_offsets
[
i
+
1
]
=
tot_offset
;
}
}
__global__
void
compute_arg_sorts
(
const
int
*
__restrict__
topk_ids
,
int32_t
*
input_permutation
,
int32_t
*
output_permutation
,
int32_t
*
atomic_buffer
,
const
int
topk_length
,
const
int
topk
)
{
int
expert_id
=
blockIdx
.
x
;
for
(
int
i
=
threadIdx
.
x
;
i
<
topk_length
;
i
+=
THREADS_PER_EXPERT
)
{
if
(
topk_ids
[
i
]
==
expert_id
)
{
int
start
=
atomicAdd
(
&
atomic_buffer
[
expert_id
],
1
);
input_permutation
[
start
]
=
i
/
topk
;
output_permutation
[
i
]
=
start
;
}
}
}
void
get_cutlass_moe_mm_data_caller
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
)
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
topk_ids
.
device
().
index
());
auto
options_int32
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt32
).
device
(
topk_ids
.
device
());
torch
::
Tensor
atomic_buffer
=
torch
::
zeros
(
num_experts
,
options_int32
);
int
num_threads
=
min
(
THREADS_PER_EXPERT
,
topk_ids
.
numel
());
compute_problem_sizes
<<<
num_experts
,
num_threads
,
0
,
stream
>>>
(
static_cast
<
const
int32_t
*>
(
topk_ids
.
data_ptr
()),
static_cast
<
int32_t
*>
(
problem_sizes1
.
data_ptr
()),
static_cast
<
int32_t
*>
(
problem_sizes2
.
data_ptr
()),
static_cast
<
int32_t
*>
(
atomic_buffer
.
data_ptr
()),
topk_ids
.
numel
(),
n
,
k
);
compute_expert_offsets
<<<
1
,
1
,
0
,
stream
>>>
(
static_cast
<
const
int32_t
*>
(
problem_sizes1
.
data_ptr
()),
static_cast
<
int32_t
*>
(
expert_offsets
.
data_ptr
()),
static_cast
<
int32_t
*>
(
atomic_buffer
.
data_ptr
()),
num_experts
);
compute_arg_sorts
<<<
num_experts
,
num_threads
,
0
,
stream
>>>
(
static_cast
<
const
int32_t
*>
(
topk_ids
.
data_ptr
()),
static_cast
<
int32_t
*>
(
input_permutation
.
data_ptr
()),
static_cast
<
int32_t
*>
(
output_permutation
.
data_ptr
()),
static_cast
<
int32_t
*>
(
atomic_buffer
.
data_ptr
()),
topk_ids
.
numel
(),
topk_ids
.
size
(
1
));
}
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
fcfc474d
...
@@ -29,6 +29,20 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
...
@@ -29,6 +29,20 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_moe_mm_sm90
(
torch
::
Tensor
&
out_tensors
,
torch
::
Tensor
const
&
a_tensors
,
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
);
void
get_cutlass_moe_mm_data_caller
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
);
#endif
#endif
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
...
@@ -102,6 +116,19 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
...
@@ -102,6 +116,19 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
return
false
;
return
false
;
}
}
bool
cutlass_group_gemm_supported
(
int64_t
cuda_device_capability
)
{
// CUTLASS groped FP8 kernels need at least CUDA 12.3
// and SM90 (Hopper)
#if defined CUDA_VERSION
if
(
cuda_device_capability
==
90
)
{
return
CUDA_VERSION
>=
12030
;
}
#endif
return
false
;
}
void
cutlass_scaled_mm
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
...
@@ -168,6 +195,46 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
...
@@ -168,6 +195,46 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
version_num
);
version_num
);
}
}
void
cutlass_moe_mm
(
torch
::
Tensor
&
out_tensors
,
torch
::
Tensor
const
&
a_tensors
,
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
)
{
int32_t
version_num
=
get_sm_version_num
();
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
cutlass_moe_mm_sm90
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
);
return
;
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled cutlass_scaled_mm for CUDA device capability: "
,
version_num
,
". Required capability: 90"
);
}
void
get_cutlass_moe_mm_data
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
)
{
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t
version_num
=
get_sm_version_num
();
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
get_cutlass_moe_mm_data_caller
(
topk_ids
,
expert_offsets
,
problem_sizes1
,
problem_sizes2
,
input_permutation
,
output_permutation
,
num_experts
,
n
,
k
);
return
;
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
"CUDA device capability: "
,
version_num
,
". Required capability: 90"
);
}
void
cutlass_scaled_mm_azp
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_azp
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
...
...
csrc/quantization/fp8/common.cu
View file @
fcfc474d
...
@@ -30,9 +30,6 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
...
@@ -30,9 +30,6 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
fp8_type
*
__restrict__
out
,
float
*
__restrict__
scale
,
fp8_type
*
__restrict__
out
,
float
*
__restrict__
scale
,
scalar_t
const
*
__restrict__
input
,
float
const
*
__restrict__
scale_ub
,
scalar_t
const
*
__restrict__
input
,
float
const
*
__restrict__
scale_ub
,
const
int
hidden_size
)
{
const
int
hidden_size
)
{
float
const
min_scaling_factor
=
1.0
f
/
(
fp8_e4m3_adjusted_max_v
<
fp8_type
>
*
512.
f
);
int
const
tid
=
threadIdx
.
x
;
int
const
tid
=
threadIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
...
@@ -67,8 +64,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
...
@@ -67,8 +64,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
token_scale
=
block_absmax_val_maybe
;
token_scale
=
block_absmax_val_maybe
;
}
}
// token scale computation
// token scale computation
token_scale
=
max
(
token_scale
/
fp8_e4m3_adjusted
_max_v
<
fp8_type
>
,
token_scale
=
max
(
token_scale
/
quant_type
_max_v
<
fp8_type
>
,
min_scaling_factor
);
min_scaling_factor
<
fp8_type
>::
val
()
);
scale
[
token_idx
]
=
token_scale
;
scale
[
token_idx
]
=
token_scale
;
}
}
__syncthreads
();
__syncthreads
();
...
...
csrc/quantization/fp8/common.cuh
View file @
fcfc474d
#pragma once
#pragma once
#include "quantization/vectorization.cuh"
#include "quantization/vectorization.cuh"
#include "quantization/utils.cuh"
#include <cmath>
#include <cmath>
#include <c10/core/ScalarType.h>
#ifndef USE_ROCM
#ifdef USE_ROCM
#include <c10/util/Float8_e4m3fn.h>
#define MAYBE_HOST_DEVICE C10_HOST_DEVICE
#else
#include <ATen/hip/HIPContext.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include "amd/quant_utils.cuh"
#include "amd/quant_utils.cuh"
// ROCm doesn't seem to need C10_HOST_DEVICE for static constexpr
#define MAYBE_HOST_DEVICE
#endif
#endif
// Determines the preferred FP8 type for the current platform.
// Determines the preferred FP8 type for the current platform.
...
@@ -31,29 +23,6 @@ static bool is_fp8_ocp() {
...
@@ -31,29 +23,6 @@ static bool is_fp8_ocp() {
#endif
#endif
}
}
template
<
typename
T
>
struct
fp8_e4m3_adjusted_max
;
template
<
>
struct
fp8_e4m3_adjusted_max
<
c10
::
Float8_e4m3fn
>
{
static
constexpr
c10
::
Float8_e4m3fn
val
()
{
return
std
::
numeric_limits
<
c10
::
Float8_e4m3fn
>::
max
();
}
};
// Using the default max value from pytorch (240.0 0x7F) will cause accuracy
// issues when running dynamic quantization. Here use 224.0 0x7E for rocm.
template
<
>
struct
fp8_e4m3_adjusted_max
<
c10
::
Float8_e4m3fnuz
>
{
static
constexpr
c10
::
Float8_e4m3fnuz
val
()
{
return
c10
::
Float8_e4m3fnuz
(
0x7E
,
c10
::
Float8_e4m3fnuz
::
from_bits
());
}
};
template
<
typename
T
>
MAYBE_HOST_DEVICE
static
constexpr
T
fp8_e4m3_adjusted_max_v
=
fp8_e4m3_adjusted_max
<
T
>::
val
();
namespace
vllm
{
namespace
vllm
{
__device__
__forceinline__
float
atomicMaxFloat
(
float
*
addr
,
float
value
)
{
__device__
__forceinline__
float
atomicMaxFloat
(
float
*
addr
,
float
value
)
{
...
@@ -76,8 +45,8 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
...
@@ -76,8 +45,8 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
x
=
val
/
scale
;
x
=
val
/
scale
;
}
}
float
r
=
fmax
(
-
fp8_e4m3_adjusted_max_v
<
fp8_type
>
,
float
r
=
fmin
(
x
,
fp8_e4m3_adjusted
_max_v
<
fp8_type
>
));
fmax
(
-
quant_type_max_v
<
fp8_type
>
,
fmin
(
x
,
quant_type
_max_v
<
fp8_type
>
));
#ifndef USE_ROCM
#ifndef USE_ROCM
return
static_cast
<
fp8_type
>
(
r
);
return
static_cast
<
fp8_type
>
(
r
);
#else
#else
...
@@ -123,7 +92,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
...
@@ -123,7 +92,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
// Finally, since cache[0] contains the maximum for this thread block,
// Finally, since cache[0] contains the maximum for this thread block,
// atomically write the max to the target location
// atomically write the max to the target location
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
atomicMaxFloat
(
scale
,
cache
[
0
]
/
fp8_e4m3_adjusted
_max_v
<
fp8_type
>
);
atomicMaxFloat
(
scale
,
cache
[
0
]
/
quant_type
_max_v
<
fp8_type
>
);
}
}
}
}
...
...
csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
View file @
fcfc474d
...
@@ -14,8 +14,7 @@ __device__ void rms_norm_dynamic_per_token_quant_vec(
...
@@ -14,8 +14,7 @@ __device__ void rms_norm_dynamic_per_token_quant_vec(
float
*
__restrict__
scales
,
// [num_tokens]
float
*
__restrict__
scales
,
// [num_tokens]
scalar_t
const
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
const
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
const
*
__restrict__
weight
,
// [hidden_size]
scalar_t
const
*
__restrict__
weight
,
// [hidden_size]
float
const
*
scale_ub
,
float
const
var_epsilon
,
float
const
*
scale_ub
,
float
const
var_epsilon
,
int32_t
const
hidden_size
,
float
const
min_scaling_factor
,
int32_t
const
hidden_size
,
scalar_t
*
__restrict__
residual
=
nullptr
)
{
scalar_t
*
__restrict__
residual
=
nullptr
)
{
float
rms
=
0.0
f
;
float
rms
=
0.0
f
;
float
token_scale
=
0.0
f
;
float
token_scale
=
0.0
f
;
...
@@ -27,8 +26,8 @@ __device__ void rms_norm_dynamic_per_token_quant_vec(
...
@@ -27,8 +26,8 @@ __device__ void rms_norm_dynamic_per_token_quant_vec(
// Compute scale
// Compute scale
vllm
::
vectorized
::
compute_dynamic_per_token_scales
<
scalar_t
,
scalar_out_t
,
vllm
::
vectorized
::
compute_dynamic_per_token_scales
<
scalar_t
,
scalar_out_t
,
has_residual
>
(
has_residual
>
(
&
token_scale
,
scales
,
input
,
weight
,
rms
,
scale_ub
,
min_scaling_factor
,
&
token_scale
,
scales
,
input
,
weight
,
rms
,
scale_ub
,
hidden_size
,
hidden_size
,
residual
);
residual
);
// RMS Norm + Quant
// RMS Norm + Quant
if
constexpr
(
std
::
is_same_v
<
scalar_out_t
,
int8_t
>
)
{
if
constexpr
(
std
::
is_same_v
<
scalar_out_t
,
int8_t
>
)
{
...
@@ -50,8 +49,7 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
...
@@ -50,8 +49,7 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
float
*
__restrict__
scales
,
// [num_tokens]
float
*
__restrict__
scales
,
// [num_tokens]
scalar_t
const
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
const
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
const
*
__restrict__
weight
,
// [hidden_size]
scalar_t
const
*
__restrict__
weight
,
// [hidden_size]
float
const
*
scale_ub
,
float
const
var_epsilon
,
float
const
*
scale_ub
,
float
const
var_epsilon
,
int32_t
const
hidden_size
,
float
const
min_scaling_factor
,
int32_t
const
hidden_size
,
scalar_t
*
__restrict__
residual
=
nullptr
)
{
scalar_t
*
__restrict__
residual
=
nullptr
)
{
// For vectorization, token_input and token_output pointers need to be
// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
// aligned at 8-byte and 4-byte addresses respectively.
...
@@ -60,8 +58,8 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
...
@@ -60,8 +58,8 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
if
(
can_vectorize
)
{
if
(
can_vectorize
)
{
return
rms_norm_dynamic_per_token_quant_vec
<
scalar_t
,
scalar_out_t
,
return
rms_norm_dynamic_per_token_quant_vec
<
scalar_t
,
scalar_out_t
,
has_residual
>
(
has_residual
>
(
out
,
scales
,
input
,
weight
,
scale_ub
,
var_epsilon
,
min_scaling_factor
,
out
,
scales
,
input
,
weight
,
scale_ub
,
var_epsilon
,
hidden_size
,
hidden_size
,
residual
);
residual
);
}
}
float
rms
=
0.0
f
;
float
rms
=
0.0
f
;
...
@@ -72,8 +70,8 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
...
@@ -72,8 +70,8 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
var_epsilon
,
residual
);
var_epsilon
,
residual
);
// Compute Scale
// Compute Scale
vllm
::
compute_dynamic_per_token_scales
<
scalar_t
,
scalar_out_t
,
has_residual
>
(
vllm
::
compute_dynamic_per_token_scales
<
scalar_t
,
scalar_out_t
,
has_residual
>
(
&
token_scale
,
scales
,
input
,
weight
,
rms
,
scale_ub
,
min_scaling_factor
,
&
token_scale
,
scales
,
input
,
weight
,
rms
,
scale_ub
,
hidden_size
,
hidden_size
,
residual
);
residual
);
// RMS Norm + Quant
// RMS Norm + Quant
if
constexpr
(
std
::
is_same_v
<
scalar_out_t
,
int8_t
>
)
{
if
constexpr
(
std
::
is_same_v
<
scalar_out_t
,
int8_t
>
)
{
...
@@ -105,11 +103,6 @@ void rms_norm_dynamic_per_token_quant_dispatch(
...
@@ -105,11 +103,6 @@ void rms_norm_dynamic_per_token_quant_dispatch(
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
float
min_scaling_factor
=
out
.
dtype
()
==
torch
::
kInt8
?
std
::
numeric_limits
<
float
>::
epsilon
()
:
1.0
f
/
(
std
::
numeric_limits
<
c10
::
Float8_e4m3fn
>::
max
()
*
512.
f
);
if
(
residual
.
has_value
())
{
if
(
residual
.
has_value
())
{
VLLM_DISPATCH_QUANT_TYPES
(
VLLM_DISPATCH_QUANT_TYPES
(
out
.
scalar_type
(),
"rms_norm_dynamic_per_token_quant_kernel"
,
[
&
]
{
out
.
scalar_type
(),
"rms_norm_dynamic_per_token_quant_kernel"
,
[
&
]
{
...
@@ -119,8 +112,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
...
@@ -119,8 +112,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
out
.
data_ptr
<
scalar_t
>
(),
scales
.
data_ptr
<
float
>
(),
out
.
data_ptr
<
scalar_t
>
(),
scales
.
data_ptr
<
float
>
(),
input
.
data_ptr
<
scalar_in_t
>
(),
weight
.
data_ptr
<
scalar_in_t
>
(),
input
.
data_ptr
<
scalar_in_t
>
(),
weight
.
data_ptr
<
scalar_in_t
>
(),
scale_ub
.
has_value
()
?
scale_ub
->
data_ptr
<
float
>
()
:
nullptr
,
scale_ub
.
has_value
()
?
scale_ub
->
data_ptr
<
float
>
()
:
nullptr
,
var_epsilon
,
min_scaling_factor
,
hidden_size
,
var_epsilon
,
hidden_size
,
residual
->
data_ptr
<
scalar_in_t
>
());
residual
->
data_ptr
<
scalar_in_t
>
());
});
});
}
else
{
}
else
{
...
@@ -132,7 +124,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
...
@@ -132,7 +124,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
out
.
data_ptr
<
scalar_t
>
(),
scales
.
data_ptr
<
float
>
(),
out
.
data_ptr
<
scalar_t
>
(),
scales
.
data_ptr
<
float
>
(),
input
.
data_ptr
<
scalar_in_t
>
(),
weight
.
data_ptr
<
scalar_in_t
>
(),
input
.
data_ptr
<
scalar_in_t
>
(),
weight
.
data_ptr
<
scalar_in_t
>
(),
scale_ub
.
has_value
()
?
scale_ub
->
data_ptr
<
float
>
()
:
nullptr
,
scale_ub
.
has_value
()
?
scale_ub
->
data_ptr
<
float
>
()
:
nullptr
,
var_epsilon
,
min_scaling_factor
,
hidden_size
,
nullptr
);
var_epsilon
,
hidden_size
,
nullptr
);
});
});
}
}
}
}
...
...
csrc/quantization/fused_kernels/layernorm_utils.cuh
View file @
fcfc474d
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
*/
*/
#include "quantization/vectorization.cuh"
#include "quantization/vectorization.cuh"
#include "quantization/utils.cuh"
#include "quant_conversions.cuh"
#include "quant_conversions.cuh"
#ifndef USE_ROCM
#ifndef USE_ROCM
...
@@ -51,11 +52,11 @@ __device__ void compute_dynamic_per_token_scales(
...
@@ -51,11 +52,11 @@ __device__ void compute_dynamic_per_token_scales(
float
*
__restrict__
token_scale
,
float
*
__restrict__
all_token_scales
,
float
*
__restrict__
token_scale
,
float
*
__restrict__
all_token_scales
,
scalar_t
const
*
__restrict__
input
,
scalar_t
const
*
__restrict__
weight
,
scalar_t
const
*
__restrict__
input
,
scalar_t
const
*
__restrict__
weight
,
float
const
rms
,
float
const
*
__restrict__
scale_ub
,
float
const
rms
,
float
const
*
__restrict__
scale_ub
,
float
const
min_scaling_factor
,
int32_t
const
hidden_size
,
int32_t
const
hidden_size
,
scalar_t
const
*
__restrict__
residual
=
nullptr
)
{
scalar_t
const
*
__restrict__
residual
=
nullptr
)
{
int64_t
const
token_offset
=
blockIdx
.
x
*
static_cast
<
int64_t
>
(
hidden_size
);
int64_t
const
token_offset
=
blockIdx
.
x
*
static_cast
<
int64_t
>
(
hidden_size
);
;
;
constexpr
scalar_out_t
qmax
{
std
::
numeric_limits
<
scalar_out_t
>
::
max
()
};
constexpr
scalar_out_t
qmax
{
quant_type_max_v
<
scalar_out_t
>
};
float
block_absmax_val_maybe
=
0.0
f
;
float
block_absmax_val_maybe
=
0.0
f
;
for
(
auto
i
=
threadIdx
.
x
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
for
(
auto
i
=
threadIdx
.
x
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
...
@@ -83,7 +84,7 @@ __device__ void compute_dynamic_per_token_scales(
...
@@ -83,7 +84,7 @@ __device__ void compute_dynamic_per_token_scales(
scale
=
block_absmax_val_maybe
;
scale
=
block_absmax_val_maybe
;
}
}
// token scale computation
// token scale computation
scale
=
max
(
scale
/
qmax
,
min_scaling_factor
);
scale
=
max
(
scale
/
qmax
,
min_scaling_factor
<
scalar_out_t
>::
val
()
);
s_token_scale
=
scale
;
// Shared memory store
s_token_scale
=
scale
;
// Shared memory store
all_token_scales
[
blockIdx
.
x
]
=
scale
;
// Global output store
all_token_scales
[
blockIdx
.
x
]
=
scale
;
// Global output store
}
}
...
@@ -184,7 +185,7 @@ __device__ void compute_dynamic_per_token_scales(
...
@@ -184,7 +185,7 @@ __device__ void compute_dynamic_per_token_scales(
float
*
__restrict__
token_scale
,
float
*
__restrict__
all_token_scales
,
float
*
__restrict__
token_scale
,
float
*
__restrict__
all_token_scales
,
scalar_t
const
*
__restrict__
input
,
scalar_t
const
*
__restrict__
weight
,
scalar_t
const
*
__restrict__
input
,
scalar_t
const
*
__restrict__
weight
,
float
const
rms
,
float
const
*
__restrict__
scale_ub
,
float
const
rms
,
float
const
*
__restrict__
scale_ub
,
float
const
min_scaling_factor
,
int32_t
const
hidden_size
,
int32_t
const
hidden_size
,
scalar_t
const
*
__restrict__
residual
=
nullptr
)
{
scalar_t
const
*
__restrict__
residual
=
nullptr
)
{
int64_t
const
token_offset
=
blockIdx
.
x
*
static_cast
<
int64_t
>
(
hidden_size
);
int64_t
const
token_offset
=
blockIdx
.
x
*
static_cast
<
int64_t
>
(
hidden_size
);
;
;
...
@@ -200,7 +201,7 @@ __device__ void compute_dynamic_per_token_scales(
...
@@ -200,7 +201,7 @@ __device__ void compute_dynamic_per_token_scales(
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
&
residual
[
token_offset
]);
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
&
residual
[
token_offset
]);
}
}
constexpr
scalar_out_t
qmax
{
std
::
numeric_limits
<
scalar_out_t
>
::
max
()
};
constexpr
scalar_out_t
qmax
{
quant_type_max_v
<
scalar_out_t
>
};
int32_t
const
num_vec_elems
=
hidden_size
>>
2
;
int32_t
const
num_vec_elems
=
hidden_size
>>
2
;
float
block_absmax_val_maybe
=
0.0
f
;
float
block_absmax_val_maybe
=
0.0
f
;
...
@@ -248,7 +249,7 @@ __device__ void compute_dynamic_per_token_scales(
...
@@ -248,7 +249,7 @@ __device__ void compute_dynamic_per_token_scales(
scale
=
block_absmax_val_maybe
;
scale
=
block_absmax_val_maybe
;
}
}
// token scale computation
// token scale computation
scale
=
max
(
scale
/
qmax
,
min_scaling_factor
);
scale
=
max
(
scale
/
qmax
,
min_scaling_factor
<
scalar_out_t
>::
val
()
);
s_token_scale
=
scale
;
// shared memory store
s_token_scale
=
scale
;
// shared memory store
all_token_scales
[
blockIdx
.
x
]
=
scale
;
// global output store
all_token_scales
[
blockIdx
.
x
]
=
scale
;
// global output store
}
}
...
...
csrc/quantization/fused_kernels/quant_conversions.cuh
View file @
fcfc474d
...
@@ -33,8 +33,8 @@ static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) {
...
@@ -33,8 +33,8 @@ static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) {
template
<
typename
fp8_type
>
template
<
typename
fp8_type
>
static
__device__
__forceinline__
fp8_type
float_to_fp8
(
float
const
x
)
{
static
__device__
__forceinline__
fp8_type
float_to_fp8
(
float
const
x
)
{
float
const
r
=
fmax
(
-
fp8_e4m3_adjusted_max_v
<
fp8_type
>
,
float
const
r
=
fmin
(
x
,
fp8_e4m3_adjusted
_max_v
<
fp8_type
>
));
fmax
(
-
quant_type_max_v
<
fp8_type
>
,
fmin
(
x
,
quant_type
_max_v
<
fp8_type
>
));
return
static_cast
<
fp8_type
>
(
r
);
return
static_cast
<
fp8_type
>
(
r
);
}
}
...
...
Prev
1
2
3
4
5
6
7
…
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment