Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
66b809cc
Commit
66b809cc
authored
Feb 08, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.7.2' into v0.7.2-dev
parents
37b63c24
0408efc6
Changes
1000
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
231 additions
and
49 deletions
+231
-49
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+2
-0
benchmarks/kernels/benchmark_paged_attention.py
benchmarks/kernels/benchmark_paged_attention.py
+2
-0
benchmarks/kernels/benchmark_quant.py
benchmarks/kernels/benchmark_quant.py
+2
-0
benchmarks/kernels/benchmark_rmsnorm.py
benchmarks/kernels/benchmark_rmsnorm.py
+2
-0
benchmarks/kernels/benchmark_rope.py
benchmarks/kernels/benchmark_rope.py
+2
-0
benchmarks/kernels/benchmark_shapes.py
benchmarks/kernels/benchmark_shapes.py
+2
-0
benchmarks/kernels/graph_machete_bench.py
benchmarks/kernels/graph_machete_bench.py
+2
-0
benchmarks/kernels/utils.py
benchmarks/kernels/utils.py
+2
-0
benchmarks/kernels/weight_shapes.py
benchmarks/kernels/weight_shapes.py
+2
-0
benchmarks/overheads/benchmark_hashing.py
benchmarks/overheads/benchmark_hashing.py
+2
-0
cmake/hipify.py
cmake/hipify.py
+1
-0
collect_env.py
collect_env.py
+2
-0
csrc/cache.h
csrc/cache.h
+3
-0
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+70
-14
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
+2
-0
csrc/moe/moe_align_sum_kernels.cu
csrc/moe/moe_align_sum_kernels.cu
+92
-0
csrc/moe/moe_ops.h
csrc/moe/moe_ops.h
+6
-0
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+9
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+24
-35
csrc/quantization/machete/generate.py
csrc/quantization/machete/generate.py
+2
-0
No files found.
Too many changes to show.
To preserve performance only
1000 of 1000+
files are displayed.
Plain diff
Email patch
benchmarks/kernels/benchmark_moe.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
import
argparse
import
time
from
datetime
import
datetime
...
...
benchmarks/kernels/benchmark_paged_attention.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
import
random
import
time
from
typing
import
List
,
Optional
...
...
benchmarks/kernels/benchmark_quant.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
import
time
import
torch
...
...
benchmarks/kernels/benchmark_rmsnorm.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
import
itertools
from
typing
import
Optional
,
Tuple
,
Union
...
...
benchmarks/kernels/benchmark_rope.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
from
itertools
import
accumulate
from
typing
import
List
,
Optional
...
...
benchmarks/kernels/benchmark_shapes.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
WEIGHT_SHAPES
=
{
"ideal"
:
[[
4
*
256
*
32
,
256
*
32
]],
"mistralai/Mistral-7B-v0.1/TP1"
:
[
...
...
benchmarks/kernels/graph_machete_bench.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
import
math
import
pickle
import
re
...
...
benchmarks/kernels/utils.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
import
dataclasses
from
typing
import
Any
,
Callable
,
Iterable
,
Optional
...
...
benchmarks/kernels/weight_shapes.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# Example:
...
...
benchmarks/overheads/benchmark_hashing.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
import
cProfile
import
pstats
...
...
cmake/hipify.py
View file @
66b809cc
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
#
# A command line tool for running pytorch's hipify preprocessor on CUDA
...
...
collect_env.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
# ruff: noqa
# code borrowed from https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py
...
...
csrc/cache.h
View file @
66b809cc
...
...
@@ -15,6 +15,9 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
std
::
vector
<
torch
::
Tensor
>
const
&
value_caches
,
const
torch
::
Tensor
&
block_mapping
);
// Block copy for MLA caches. `kv_caches` is a list of cache tensors (the
// implementation treats each entry as one layer's joint KV cache) and
// `block_mapping` supplies the source/destination block indices.
// NOTE(review): block_mapping is read as flattened (src, dst) int64 pairs by
// the CUDA implementation -- confirm against callers.
void
copy_blocks_mla
(
std
::
vector
<
torch
::
Tensor
>
const
&
kv_caches
,
const
torch
::
Tensor
&
block_mapping
);
void
reshape_and_cache
(
torch
::
Tensor
&
key
,
torch
::
Tensor
&
value
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
torch
::
Tensor
&
slot_mapping
,
...
...
csrc/cache_kernels.cu
View file @
66b809cc
...
...
@@ -46,7 +46,10 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
char
*
src_ptr
=
static_cast
<
char
*>
(
src
.
data_ptr
());
char
*
dst_ptr
=
static_cast
<
char
*>
(
dst
.
data_ptr
());
const
int64_t
block_size_in_bytes
=
src
.
element_size
()
*
src
[
0
].
numel
();
// We use the stride instead of numel in case the cache is padded for memory
// alignment reasons, we assume the blocks data (inclusive of any padding)
// is contiguous in memory
const
int64_t
block_size_in_bytes
=
src
.
element_size
()
*
src
.
stride
(
0
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
src_device
.
is_cuda
()
?
src_device
:
dst_device
);
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
...
...
@@ -93,6 +96,24 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
}
}
// MLA block-copy kernel operating on one joint kv_cache per layer.
// Launch grid: (num_layers, num_pairs); blockIdx.x picks the layer and
// blockIdx.y picks the (src, dst) pair read from `block_mapping`.
// `mem_footprint_per_block` is the number of scalar_t elements per cache
// block; the threads of a thread block cooperatively copy that many elements.
template <typename scalar_t>
__global__ void copy_blocks_mla_kernel(
    int64_t* cache_ptrs, const int64_t* __restrict__ block_mapping,
    const int mem_footprint_per_block) {
  const int layer = blockIdx.x;
  const int pair = blockIdx.y;

  // Each entry of cache_ptrs is a raw address stored as an int64.
  scalar_t* layer_cache = reinterpret_cast<scalar_t*>(cache_ptrs[layer]);

  // block_mapping is laid out as consecutive [src, dst] block indices.
  const int64_t src_block_idx = block_mapping[2 * pair];
  const int64_t dst_block_idx = block_mapping[2 * pair + 1];

  // 64-bit offsets: block index times per-block element count.
  const scalar_t* src_base =
      layer_cache + src_block_idx * mem_footprint_per_block;
  scalar_t* dst_base = layer_cache + dst_block_idx * mem_footprint_per_block;

  for (int elem = threadIdx.x; elem < mem_footprint_per_block;
       elem += blockDim.x) {
    dst_base[elem] = src_base[elem];
  }
}
}
// namespace vllm
// Note: the key_caches and value_caches vectors are constant but
...
...
@@ -147,6 +168,42 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
}));
}
// copy blocks kernel for MLA (assumes a joint KV-cache)
//
// For every layer in `kv_caches`, copies each source block named in
// `block_mapping` onto its destination block inside the same cache tensor.
//
//   kv_caches     - one cache tensor per layer; all must live on the same
//                   CUDA device as kv_caches[0] (only kv_caches[0] is
//                   checked here).
//   block_mapping - int64 tensor whose first dimension is the number of
//                   (src, dst) pairs; the kernel reads it as flattened
//                   [src0, dst0, src1, dst1, ...].
//                   NOTE(review): it is dereferenced on the device, so it is
//                   presumably expected to be a CUDA tensor -- confirm.
//
// Does nothing when there are no layers or no pairs to copy.
void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
                     const torch::Tensor& block_mapping) {
  int num_layers = kv_caches.size();
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = kv_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda(), "kv_cache must be on CUDA");

  int num_pairs = block_mapping.size(0);
  if (num_pairs == 0) {
    // Nothing to copy. Launching with grid.y == 0 would be an invalid CUDA
    // launch configuration, so bail out before touching the device.
    return;
  }

  // Collect the raw device address of every layer's cache so the kernel can
  // index them by layer.
  std::vector<int64_t> cache_ptrs(num_layers);
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(kv_caches[layer_idx].data_ptr());
  }
  // from_blob only wraps the host vector; the .to() call makes the device
  // copy while `cache_ptrs` is still alive.
  torch::Tensor cache_ptrs_tensor =
      torch::from_blob(cache_ptrs.data(), {num_layers}, torch::kInt64)
          .to(cache_device);

  // We use the stride instead of numel in case the cache is padded for memory
  // alignment reasons, we assume the blocks data (inclusive of any padding)
  // is contiguous in memory
  int mem_footprint_per_block = kv_caches[0].stride(0);

  dim3 grid(num_layers, num_pairs);
  dim3 block(std::min(1024, mem_footprint_per_block));
  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
      kv_caches[0].scalar_type(), "copy_blocks_mla_kernel", ([&] {
        vllm::copy_blocks_mla_kernel<scalar_t><<<grid, block, 0, stream>>>(
            cache_ptrs_tensor.data_ptr<int64_t>(),
            block_mapping.data_ptr<int64_t>(), mem_footprint_per_block);
      }));
}
namespace
vllm
{
template
<
typename
scalar_t
,
typename
cache_t
,
Fp8KVCacheDataType
kv_dt
>
...
...
@@ -382,6 +439,7 @@ __global__ void concat_and_cache_mla_kernel(
// + pe_dim)]
const
int64_t
*
__restrict__
slot_mapping
,
// [num_tokens]
const
int
block_stride
,
//
const
int
entry_stride
,
//
const
int
kv_c_stride
,
//
const
int
k_pe_stride
,
//
const
int
kv_lora_rank
,
//
...
...
@@ -402,9 +460,8 @@ __global__ void concat_and_cache_mla_kernel(
int
src_stride
,
int
dst_stride
,
int
size
,
int
offset
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
size
;
i
+=
blockDim
.
x
)
{
const
int64_t
src_idx
=
token_idx
*
src_stride
+
i
;
const
int64_t
dst_idx
=
block_idx
*
block_stride
+
block_offset
*
(
kv_lora_rank
+
pe_dim
)
+
i
+
offset
;
const
int64_t
dst_idx
=
block_idx
*
block_stride
+
block_offset
*
entry_stride
+
i
+
offset
;
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kAuto
)
{
dst
[
dst_idx
]
=
src
[
src_idx
];
}
else
{
...
...
@@ -660,16 +717,14 @@ void write_cache_multi_layers(
CALL_WRITE_CACHE_MULTI_LAYERS
);
}
#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, kv_c_stride, \
k_pe_stride, kv_lora_rank, pe_dim, block_size, \
#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));
void
concat_and_cache_mla
(
...
...
@@ -699,6 +754,7 @@ void concat_and_cache_mla(
int
kv_c_stride
=
kv_c
.
stride
(
0
);
int
k_pe_stride
=
k_pe
.
stride
(
0
);
int
block_stride
=
kv_cache
.
stride
(
0
);
int
entry_stride
=
kv_cache
.
stride
(
1
);
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
kv_lora_rank
,
512
));
...
...
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
import
enum
from
typing
import
Dict
,
Union
...
...
csrc/moe/moe_align_sum_kernels.cu
View file @
66b809cc
...
...
@@ -197,6 +197,72 @@ __global__ void moe_align_block_size_global_mem_kernel(
}
}
// taken from
// https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
//
// Groups the entries of `topk_ids` by expert and pads every expert's group
// to a multiple of `block_size`.
//
//   topk_ids              - `numel` expert ids, one per token slot.
//   sorted_token_ids      - out: token slot indices, written at each expert's
//                           padded offset.
//   expert_ids            - out: expert id owning each block_size-sized chunk.
//   total_tokens_post_pad - out: total padded length (cumsum[num_experts]).
//   cumsum                - scratch of num_experts + 1 entries holding the
//                           padded prefix sums.
//
// The shared arrays are statically sized for at most 32 * 8 = 256 experts,
// and all cross-thread coordination uses __syncthreads(), so the kernel is
// only correct when launched as a single thread block.
// NOTE(review): zeroing of shared_counts is done per warp, so the block needs
// enough warps to cover ceil(num_experts / 8) rows -- confirm at launch site.
template <typename scalar_t>
__global__ void sgl_moe_align_block_size_kernel(
    scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
    int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
    int32_t block_size, size_t numel, int32_t* cumsum) {
  __shared__ int32_t shared_counts[32][8];
  __shared__ int32_t local_offsets[256];

  const int warp_id = threadIdx.x / 32;
  const int experts_per_warp = 8;
  const int my_expert_start = warp_id * experts_per_warp;

  // Each warp clears the counter row it owns.
  for (int i = 0; i < experts_per_warp; ++i) {
    if (my_expert_start + i < num_experts) {
      shared_counts[warp_id][i] = 0;
    }
  }

  // The counting loop below increments rows owned by *other* warps, so every
  // row must be zeroed before any thread starts counting.
  __syncthreads();

  // Every thread handles one contiguous slice of topk_ids.
  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int expert_id = topk_ids[i];
    int warp_idx = expert_id / experts_per_warp;
    int expert_offset = expert_id % experts_per_warp;
    atomicAdd(&shared_counts[warp_idx][expert_offset], 1);
  }

  __syncthreads();

  // Thread 0 serially builds the padded prefix sum over all experts.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      int warp_idx = (i - 1) / experts_per_warp;
      int expert_offset = (i - 1) % experts_per_warp;
      int expert_count = shared_counts[warp_idx][expert_offset];

      // Round this expert's count up to a whole number of blocks.
      cumsum[i] =
          cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
    }
    *total_tokens_post_pad = cumsum[num_experts];
  }

  __syncthreads();

  // One thread per expert: label that expert's blocks and seed its write
  // cursor with the start of its padded range.
  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
         i += block_size) {
      expert_ids[i / block_size] = threadIdx.x;
    }
    local_offsets[threadIdx.x] = cumsum[threadIdx.x];
  }

  __syncthreads();

  // Scatter: each token slot claims the next free position of its expert.
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int32_t expert_id = topk_ids[i];
    int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1);
    sorted_token_ids[rank_post_pad] = i;
  }
}
template
<
typename
scalar_t
,
int
TOPK
>
__global__
void
moe_sum_kernel
(
scalar_t
*
__restrict__
out
,
// [..., d]
...
...
@@ -305,6 +371,32 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
}
}
// Host entry point for sgl_moe_align_block_size_kernel.
//
//   topk_ids            - integral tensor of expert ids (dtype dispatched).
//   num_experts         - number of experts; must be in [0, 256] because the
//                         kernel's shared-memory tables are statically sized
//                         for 256 experts.
//   block_size          - padding granularity; must be positive (the kernel
//                         divides by it).
//   sorted_token_ids,
//   experts_ids,
//   num_tokens_post_pad - int32 output tensors filled by the kernel.
//
// Launches a single 1024-thread block on the current CUDA stream.
void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                              int64_t block_size,
                              torch::Tensor sorted_token_ids,
                              torch::Tensor experts_ids,
                              torch::Tensor num_tokens_post_pad) {
  TORCH_CHECK(num_experts >= 0 && num_experts <= 256,
              "sgl_moe_align_block_size supports at most 256 experts, got ",
              num_experts);
  TORCH_CHECK(block_size > 0,
              "sgl_moe_align_block_size requires block_size > 0, got ",
              block_size);

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_INTEGRAL_TYPES(
      topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
        // Global-memory scratch for the kernel's padded prefix sums
        // (num_experts + 1 int32 entries on the same device as topk_ids).
        auto options_int =
            torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
        torch::Tensor cumsum_buffer =
            torch::empty({num_experts + 1}, options_int);

        auto kernel = vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>;
        // The kernel takes 32-bit sizes; the range checks above make these
        // narrowing conversions safe for num_experts.
        kernel<<<1, 1024, 0, stream>>>(
            topk_ids.data_ptr<scalar_t>(),
            sorted_token_ids.data_ptr<int32_t>(),
            experts_ids.data_ptr<int32_t>(),
            num_tokens_post_pad.data_ptr<int32_t>(),
            static_cast<int32_t>(num_experts),
            static_cast<int32_t>(block_size), topk_ids.numel(),
            cumsum_buffer.data_ptr<int32_t>());
      });
}
void
moe_sum
(
torch
::
Tensor
&
input
,
// [num_tokens, topk, hidden_size]
torch
::
Tensor
&
output
)
// [num_tokens, hidden_size]
{
...
...
csrc/moe/moe_ops.h
View file @
66b809cc
...
...
@@ -12,3 +12,9 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t
block_size
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
);
// Alternative block-size alignment op; takes the same argument list as
// moe_align_block_size. The last three tensors are outputs.
// NOTE(review): the implementation is adapted from sglang and uses
// fixed-size shared-memory tables -- confirm the supported num_experts range
// with the CUDA source before relying on it for other models.
void
sgl_moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int64_t
num_experts
,
int64_t
block_size
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
);
csrc/moe/torch_bindings.cpp
View file @
66b809cc
...
...
@@ -22,6 +22,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" Tensor! num_tokens_post_pad) -> ()"
);
m
.
impl
(
"moe_align_block_size"
,
torch
::
kCUDA
,
&
moe_align_block_size
);
// temporarily adapted from
// https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
m
.
def
(
"sgl_moe_align_block_size(Tensor topk_ids, int num_experts,"
" int block_size, Tensor! sorted_token_ids,"
" Tensor! experts_ids,"
" Tensor! num_tokens_post_pad) -> ()"
);
m
.
impl
(
"sgl_moe_align_block_size"
,
torch
::
kCUDA
,
&
sgl_moe_align_block_size
);
#ifndef USE_ROCM
m
.
def
(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
View file @
66b809cc
...
...
@@ -16,29 +16,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
using
GroupShape
=
std
::
array
<
int64_t
,
2
>
;
int
M
=
a
.
size
(
0
),
N
=
b
.
size
(
1
),
K
=
a
.
size
(
1
);
GroupShape
a_scale_group_shape
=
[
&
,
&
s
=
a_scales
]()
->
GroupShape
{
if
(
s
.
numel
()
==
1
)
return
{
M
,
K
};
// tensor-wise
if
(
s
.
dim
()
==
2
)
return
{
ceil_div
(
a
.
size
(
0
),
s
.
size
(
0
)),
ceil_div
(
a
.
size
(
1
),
s
.
size
(
1
))};
TORCH_CHECK
(
false
,
"Unsupported scale shape for scale_a"
);
}();
GroupShape
b_scale_group_shape
=
[
&
,
&
s
=
b_scales
]()
->
GroupShape
{
if
(
s
.
numel
()
==
1
)
return
{
K
,
N
};
// tensor-wise
if
(
s
.
dim
()
==
2
)
return
{
ceil_div
(
b
.
size
(
0
),
s
.
size
(
0
)),
ceil_div
(
b
.
size
(
1
),
s
.
size
(
1
))};
TORCH_CHECK
(
false
,
"Unsupported scale shape for scale_b"
);
}();
if
((
a_scale_group_shape
==
GroupShape
{
M
,
K
}
||
a_scale_group_shape
==
GroupShape
{
1
,
K
})
&&
(
b_scale_group_shape
==
GroupShape
{
K
,
N
}
||
b_scale_group_shape
==
GroupShape
{
K
,
1
}))
{
// "standard per-tensor/per-token/per-channel" scaling
if
((
a_scales
.
numel
()
==
1
||
a_scales
.
numel
()
==
a
.
size
(
0
))
&&
(
b_scales
.
numel
()
==
1
||
b_scales
.
numel
()
==
b
.
size
(
1
)))
{
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
)
{
vllm
::
cutlass_scaled_mm_sm90_fp8
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
...
...
@@ -46,25 +28,32 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
vllm
::
cutlass_scaled_mm_sm90_int8
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
}
else
if
(
a_scale_group_shape
==
GroupShape
{
1
,
128
}
&&
b_scale_group_shape
==
GroupShape
{
128
,
128
})
{
}
else
{
using
GroupShape
=
std
::
array
<
int64_t
,
2
>
;
auto
make_group_shape
=
[](
torch
::
Tensor
const
&
x
,
torch
::
Tensor
const
&
s
)
->
GroupShape
{
TORCH_CHECK
(
s
.
dim
()
==
2
,
"cutlass_scaled_mm group scales must be 2D"
);
return
{
ceil_div
(
x
.
size
(
0
),
s
.
size
(
0
)),
ceil_div
(
x
.
size
(
1
),
s
.
size
(
1
))};
};
GroupShape
a_scale_group_shape
=
make_group_shape
(
a
,
a_scales
);
GroupShape
b_scale_group_shape
=
make_group_shape
(
b
,
b_scales
);
// 1x128 per-token group scales for activations
// 128x128 blockwise scales for weights
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
&&
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
,
"Currently only FP8 is supported for A group shape 1x128 and "
"B group shape 128x128"
);
TORCH_CHECK
(
!
bias
,
"Bias not yet supported blockwise scaled_mm"
);
vllm
::
cutlass_scaled_mm_blockwise_sm90_fp8
(
c
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
TORCH_CHECK
(
false
,
"Unsupported scale group shapes for CUTLASS 3.x GEMM.
\n
"
"a_scale_group_shape must be [1, 128], got: ["
,
TORCH_CHECK
((
a_scale_group_shape
==
GroupShape
{
1
,
128
}
&&
b_scale_group_shape
==
GroupShape
{
128
,
128
}
&&
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
&&
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
),
"cutlass_scaled_mm only supports datatype float8_e4m3fn.
\n
"
"a_scale_group_shape must be [1, 128]. Got: ["
,
a_scale_group_shape
[
0
],
", "
,
a_scale_group_shape
[
1
],
"]
\n
"
"b_scale_group_shape must be [128, 128]
, g
ot: ["
,
"b_scale_group_shape must be [128, 128]
. G
ot: ["
,
b_scale_group_shape
[
0
],
", "
,
b_scale_group_shape
[
1
],
"]"
);
TORCH_CHECK
(
!
bias
,
"Bias not yet supported blockwise scaled_mm"
);
vllm
::
cutlass_scaled_mm_blockwise_sm90_fp8
(
c
,
a
,
b
,
a_scales
,
b_scales
);
}
}
...
...
csrc/quantization/machete/generate.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
import
itertools
import
math
import
os
...
...
Prev
1
2
3
4
5
6
7
…
50
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment