Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
99324e25
Commit
99324e25
authored
Jul 12, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.9.2' into v0.9.2-ori
parents
cc7f22a8
a5dd03c1
Changes
475
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1267 additions
and
281 deletions
+1267
-281
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+144
-104
csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh
csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh
+63
-1
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
+6
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
...ization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
+38
-18
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu
+24
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8_dispatch.cuh
...ization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8_dispatch.cuh
+67
-0
csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu
...ation/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu
+374
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu
+34
-0
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+21
-5
csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
+2
-0
csrc/quantization/fp4/nvfp4_experts_quant.cu
csrc/quantization/fp4/nvfp4_experts_quant.cu
+228
-50
csrc/quantization/fp4/nvfp4_quant_kernels.cu
csrc/quantization/fp4/nvfp4_quant_kernels.cu
+1
-1
csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
+1
-1
csrc/quantization/fp8/amd/quant_utils.cuh
csrc/quantization/fp8/amd/quant_utils.cuh
+0
-2
csrc/quantization/gguf/gguf_kernel.cu
csrc/quantization/gguf/gguf_kernel.cu
+24
-23
csrc/quantization/gguf/mmvq.cuh
csrc/quantization/gguf/mmvq.cuh
+65
-62
csrc/quantization/gptq/q_gemm.cu
csrc/quantization/gptq/q_gemm.cu
+0
-8
csrc/quantization/gptq_marlin/marlin_template.h
csrc/quantization/gptq_marlin/marlin_template.h
+0
-2
csrc/quantization/machete/machete_mainloop.cuh
csrc/quantization/machete/machete_mainloop.cuh
+3
-4
csrc/quantization/vectorization_utils.cuh
csrc/quantization/vectorization_utils.cuh
+172
-0
No files found.
Too many changes to show.
To preserve performance only
475 of 475+
files are displayed.
Plain diff
Email patch
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
View file @
99324e25
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <cmath>
#include "../../dispatch_utils.h"
#include "../vectorization_utils.cuh"
#ifndef USE_ROCM
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#include <cub/util_type.cuh>
#else
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#include <hipcub/util_type.hpp>
#endif
static
inline
__device__
int8_t
float_to_int8_rn
(
float
x
)
{
...
...
@@ -103,134 +105,172 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
namespace
vllm
{
template
<
typename
scalar_t
,
typename
scale_t
ype
>
template
<
typename
scalar_t
,
typename
scale_t
>
__global__
void
static_scaled_int8_quant_kernel
(
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scale_type
const
*
scale_ptr
,
const
int
hidden_size
)
{
int
const
tid
=
threadIdx
.
x
;
int64_t
const
token_idx
=
blockIdx
.
x
;
scale_type
const
scale
=
*
scale_ptr
;
const
scalar_t
*
__restrict__
input
,
int8_t
*
__restrict__
output
,
const
scale_t
*
scale_ptr
,
const
int
hidden_size
)
{
const
int
tid
=
threadIdx
.
x
;
const
int
stride
=
blockDim
.
x
;
const
int64_t
token_idx
=
blockIdx
.
x
;
const
float
scale
=
*
scale_ptr
;
// Must be performed using 64-bit math to avoid integer overflow.
o
ut
+
=
token_idx
*
hidden_size
;
input
+
=
token_idx
*
hidden_size
;
const
scalar_t
*
row_in
=
inp
ut
+
token_idx
*
hidden_size
;
in
t8_t
*
row_out
=
out
put
+
token_idx
*
hidden_size
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
out
[
i
]
=
float_to_int8_rn
(
static_cast
<
float
>
(
input
[
i
])
/
scale
);
}
vectorize_with_alignment
<
16
>
(
row_in
,
row_out
,
hidden_size
,
tid
,
stride
,
[
=
]
__device__
(
int8_t
&
dst
,
const
scalar_t
&
src
)
{
dst
=
float_to_int8_rn
(
static_cast
<
float
>
(
src
)
/
scale
);
});
}
template
<
typename
scalar_t
,
typename
scale_t
ype
,
typename
azp_t
ype
>
template
<
typename
scalar_t
,
typename
scale_t
,
typename
azp_t
>
__global__
void
static_scaled_int8_azp_quant_kernel
(
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scale_type
const
*
scale_ptr
,
azp_type
const
*
azp_ptr
,
const
int
hidden_size
)
{
int
const
tid
=
threadIdx
.
x
;
int64_t
const
token_idx
=
blockIdx
.
x
;
scale_type
const
scale
=
*
scale_ptr
;
azp_type
const
azp
=
*
azp_ptr
;
const
scalar_t
*
__restrict__
input
,
int8_t
*
__restrict__
output
,
const
scale_t
*
scale_ptr
,
const
azp_t
*
azp_ptr
,
const
int
hidden_size
)
{
const
int
tid
=
threadIdx
.
x
;
const
int
stride
=
blockDim
.
x
;
const
int64_t
token_idx
=
blockIdx
.
x
;
const
float
scale
=
*
scale_ptr
;
const
azp_t
azp
=
*
azp_ptr
;
const
float
inv_s
=
1.0
f
/
scale
;
// Must be performed using 64-bit math to avoid integer overflow.
out
+=
token_idx
*
hidden_size
;
input
+=
token_idx
*
hidden_size
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
auto
const
val
=
static_cast
<
float
>
(
input
[
i
]);
auto
const
quant_val
=
int32_to_int8
(
float_to_int32_rn
(
val
/
scale
)
+
azp
);
out
[
i
]
=
quant_val
;
}
const
scalar_t
*
row_in
=
input
+
token_idx
*
hidden_size
;
int8_t
*
row_out
=
output
+
token_idx
*
hidden_size
;
vectorize_with_alignment
<
16
>
(
row_in
,
row_out
,
hidden_size
,
tid
,
stride
,
[
=
]
__device__
(
int8_t
&
dst
,
const
scalar_t
&
src
)
{
const
auto
v
=
static_cast
<
float
>
(
src
)
*
inv_s
;
dst
=
int32_to_int8
(
float_to_int32_rn
(
v
)
+
azp
);
});
}
template
<
typename
scalar_t
,
typename
scale_t
ype
>
template
<
typename
scalar_t
,
typename
scale_t
>
__global__
void
dynamic_scaled_int8_quant_kernel
(
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scale_type
*
scale
,
const
int
hidden_size
)
{
int
const
tid
=
threadIdx
.
x
;
int64_t
const
token_idx
=
blockIdx
.
x
;
float
absmax_val
=
0.0
f
;
float
const
zero
=
0.0
f
;
const
scalar_t
*
__restrict__
input
,
int8_t
*
__restrict__
output
,
scale_t
*
scale_out
,
const
int
hidden_size
)
{
const
int
tid
=
threadIdx
.
x
;
const
int
stride
=
blockDim
.
x
;
const
int64_t
token_idx
=
blockIdx
.
x
;
// Must be performed using 64-bit math to avoid integer overflow.
o
ut
+
=
token_idx
*
hidden_size
;
input
+
=
token_idx
*
hidden_size
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
float
val
=
static_cast
<
float
>
(
input
[
i
])
;
v
al
=
val
>
zero
?
val
:
-
val
;
absmax_val
=
val
>
absmax_val
?
val
:
absmax_val
;
}
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStorage
;
float
const
block_absmax_val_maybe
=
BlockReduce
(
reduceStorage
).
Reduce
(
absmax_val
,
cub
::
Max
{},
blockDim
.
x
);
__shared__
float
block_
absmax
_val
;
const
scalar_t
*
row_in
=
inp
ut
+
token_idx
*
hidden_size
;
in
t8_t
*
row_out
=
out
put
+
token_idx
*
hidden_size
;
// calculate for absmax
float
thread_max
=
0.
f
;
v
ectorize_read_with_alignment
<
16
>
(
row_in
,
hidden_size
,
tid
,
stride
,
[
&
]
__device__
(
const
scalar_t
&
src
)
{
const
float
v
=
fabsf
(
static_cast
<
float
>
(
src
));
thread_max
=
fmaxf
(
thread_max
,
v
);
})
;
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
256
>
;
__shared__
typename
BlockReduce
::
TempStorage
tmp
;
float
block_max
=
BlockReduce
(
tmp
).
Reduce
(
thread_max
,
cub
::
Max
{},
blockDim
.
x
);
__shared__
float
absmax
;
if
(
tid
==
0
)
{
block_
absmax
_val
=
block_
absmax_val_maybe
;
scale
[
token_idx
]
=
block_
absmax
_val
/
127.
0
f
;
absmax
=
block_
max
;
scale
_out
[
blockIdx
.
x
]
=
absmax
/
127.
f
;
}
__syncthreads
();
float
const
tmp_scale
=
127.0
f
/
block_absmax_val
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
out
[
i
]
=
float_to_int8_rn
(
static_cast
<
float
>
(
input
[
i
])
*
tmp_scale
);
}
float
inv_s
=
(
absmax
==
0.
f
)
?
0.
f
:
127.
f
/
absmax
;
// 2. quantize
vectorize_with_alignment
<
16
>
(
row_in
,
row_out
,
hidden_size
,
tid
,
stride
,
[
=
]
__device__
(
int8_t
&
dst
,
const
scalar_t
&
src
)
{
dst
=
float_to_int8_rn
(
static_cast
<
float
>
(
src
)
*
inv_s
);
});
}
template
<
typename
scalar_t
,
typename
scale_type
,
typename
azp_type
>
__global__
void
dynamic_scaled_int8_azp_quant_kernel
(
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scale_type
*
scale
,
azp_type
*
azp
,
const
int
hidden_size
)
{
int64_t
const
token_idx
=
blockIdx
.
x
;
// MinMax structure to hold min and max values in one go
struct
MinMax
{
float
min
,
max
;
// Must be performed using 64-bit math to avoid integer overflow.
out
+=
token_idx
*
hidden_size
;
input
+=
token_idx
*
hidden_size
;
// Scan for the min and max value for this token
float
max_val
=
std
::
numeric_limits
<
float
>::
min
();
float
min_val
=
std
::
numeric_limits
<
float
>::
max
();
for
(
int
i
=
threadIdx
.
x
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
auto
val
=
static_cast
<
float
>
(
input
[
i
]
);
max
_val
=
std
::
max
(
max
_val
,
v
al
);
min_val
=
std
::
min
(
min_val
,
val
)
;
__host__
__device__
MinMax
()
:
min
(
std
::
numeric_limits
<
float
>::
max
()),
max
(
std
::
numeric_limits
<
float
>::
lowest
())
{}
__host__
__device__
explicit
MinMax
(
float
v
)
:
min
(
v
),
max
(
v
)
{}
// add a value to the MinMax
__host__
__device__
MinMax
&
operator
+=
(
float
v
)
{
min
=
fminf
(
min
,
v
);
max
=
f
max
f
(
max
,
v
);
return
*
this
;
}
// Reduce the max and min values across the block
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStorage
;
max_val
=
BlockReduce
(
reduceStorage
).
Reduce
(
max_val
,
cub
::
Max
{},
blockDim
.
x
);
__syncthreads
();
// Make sure min doesn't mess with max shared memory
min_val
=
BlockReduce
(
reduceStorage
).
Reduce
(
min_val
,
cub
::
Min
{},
blockDim
.
x
);
__shared__
scale_type
scale_sh
;
__shared__
azp_type
azp_sh
;
// Compute the scale and zero point and store them, only on the first thread
if
(
threadIdx
.
x
==
0
)
{
float
const
scale_val
=
(
max_val
-
min_val
)
/
255.0
f
;
// Use rounding to even (same as torch.round)
auto
const
azp_float
=
std
::
nearbyint
(
-
128.0
f
-
min_val
/
scale_val
);
auto
const
azp_val
=
static_cast
<
azp_type
>
(
azp_float
);
// Store the scale and azp into shared and global
scale
[
token_idx
]
=
scale_sh
=
scale_val
;
azp
[
token_idx
]
=
azp_sh
=
azp_val
;
// merge two MinMax objects
__host__
__device__
MinMax
&
operator
&=
(
const
MinMax
&
other
)
{
min
=
fminf
(
min
,
other
.
min
);
max
=
fmaxf
(
max
,
other
.
max
);
return
*
this
;
}
};
// Wait for the scale and azp to be computed
__syncthreads
();
__host__
__device__
inline
MinMax
operator
+
(
MinMax
a
,
float
v
)
{
return
a
+=
v
;
}
__host__
__device__
inline
MinMax
operator
&
(
MinMax
a
,
const
MinMax
&
b
)
{
return
a
&=
b
;
}
float
const
scale_val
=
scale_sh
;
azp_type
const
azp_val
=
azp_sh
;
template
<
typename
scalar_t
,
typename
scale_t
,
typename
azp_t
>
__global__
void
dynamic_scaled_int8_azp_quant_kernel
(
const
scalar_t
*
__restrict__
input
,
int8_t
*
__restrict__
output
,
scale_t
*
scale_out
,
azp_t
*
azp_out
,
const
int
hidden_size
)
{
const
int
tid
=
threadIdx
.
x
;
const
int
stride
=
blockDim
.
x
;
const
int64_t
token_idx
=
blockIdx
.
x
;
// Quantize the values
for
(
int
i
=
threadIdx
.
x
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
auto
const
val
=
static_cast
<
float
>
(
input
[
i
]);
auto
const
quant_val
=
int32_to_int8
(
float_to_int32_rn
(
val
/
scale_val
)
+
azp_val
);
out
[
i
]
=
quant_val
;
// Must be performed using 64-bit math to avoid integer overflow.
const
scalar_t
*
row_in
=
input
+
token_idx
*
hidden_size
;
int8_t
*
row_out
=
output
+
token_idx
*
hidden_size
;
// 1. calculate min & max
MinMax
thread_mm
;
vectorize_read_with_alignment
<
16
>
(
row_in
,
hidden_size
,
tid
,
stride
,
[
&
]
__device__
(
const
scalar_t
&
src
)
{
thread_mm
+=
static_cast
<
float
>
(
src
);
});
using
BlockReduce
=
cub
::
BlockReduce
<
MinMax
,
256
>
;
__shared__
typename
BlockReduce
::
TempStorage
tmp
;
MinMax
mm
=
BlockReduce
(
tmp
).
Reduce
(
thread_mm
,
[]
__device__
(
MinMax
a
,
const
MinMax
&
b
)
{
a
&=
b
;
return
a
;
},
blockDim
.
x
);
__shared__
float
scale_sh
;
__shared__
azp_t
azp_sh
;
if
(
tid
==
0
)
{
float
s
=
(
mm
.
max
-
mm
.
min
)
/
255.
f
;
float
zp
=
nearbyintf
(
-
128.
f
-
mm
.
min
/
s
);
// round-to-even
scale_sh
=
s
;
azp_sh
=
azp_t
(
zp
);
scale_out
[
blockIdx
.
x
]
=
s
;
azp_out
[
blockIdx
.
x
]
=
azp_sh
;
}
__syncthreads
();
const
float
inv_s
=
1.
f
/
scale_sh
;
const
azp_t
azp
=
azp_sh
;
// 2. quantize
vectorize_with_alignment
<
16
>
(
row_in
,
row_out
,
hidden_size
,
tid
,
stride
,
[
=
]
__device__
(
int8_t
&
dst
,
const
scalar_t
&
src
)
{
const
auto
v
=
static_cast
<
float
>
(
src
)
*
inv_s
;
dst
=
int32_to_int8
(
float_to_int32_rn
(
v
)
+
azp
);
});
}
}
// namespace vllm
...
...
@@ -247,7 +287,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
int
const
hidden_size
=
input
.
size
(
-
1
);
int
const
num_tokens
=
input
.
numel
()
/
hidden_size
;
dim3
const
grid
(
num_tokens
);
dim3
const
block
(
std
::
min
(
hidden_size
,
1024
));
dim3
const
block
(
std
::
min
(
hidden_size
,
256
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"static_scaled_int8_quant_kernel"
,
[
&
]
{
...
...
@@ -278,7 +318,7 @@ void dynamic_scaled_int8_quant(
int
const
hidden_size
=
input
.
size
(
-
1
);
int
const
num_tokens
=
input
.
numel
()
/
hidden_size
;
dim3
const
grid
(
num_tokens
);
dim3
const
block
(
std
::
min
(
hidden_size
,
1024
));
dim3
const
block
(
std
::
min
(
hidden_size
,
256
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"dynamic_scaled_int8_quant_kernel"
,
[
&
]
{
...
...
csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh
View file @
99324e25
...
...
@@ -51,7 +51,8 @@ struct cutlass_3x_gemm {
// These are the minimum alignments needed for the kernels to compile
static
constexpr
int
AlignmentAB
=
128
/
cutlass
::
sizeof_bits
<
ElementAB
>::
value
;
static
constexpr
int
AlignmentCD
=
4
;
static
constexpr
int
AlignmentCD
=
128
/
cutlass
::
sizeof_bits
<
ElementD
>::
value
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
...
...
@@ -144,4 +145,65 @@ struct cutlass_3x_gemm_sm100 {
Shape
<
int
,
int
,
int
,
int
>
,
CollectiveMainloop
,
CollectiveEpilogue
,
void
>
;
};
template
<
typename
ElementAB_
,
typename
ElementD_
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue_
,
typename
TileShape
,
typename
ClusterShape
,
typename
KernelSchedule
,
typename
EpilogueSchedule
>
struct
cutlass_3x_gemm_sm120
{
using
ElementAB
=
ElementAB_
;
using
LayoutA
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
AlignmentA
=
128
/
cutlass
::
sizeof_bits
<
ElementAB
>::
value
;
using
LayoutB
=
cutlass
::
layout
::
ColumnMajor
;
static
constexpr
int
AlignmentB
=
128
/
cutlass
::
sizeof_bits
<
ElementAB
>::
value
;
using
ElementC
=
void
;
using
LayoutC
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
AlignmentC
=
128
/
cutlass
::
sizeof_bits
<
ElementD_
>::
value
;
using
ElementD
=
ElementD_
;
using
LayoutD
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
AlignmentD
=
AlignmentC
;
using
ElementAcc
=
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
int32_t
,
float
>::
type
;
using
Epilogue
=
Epilogue_
<
ElementAcc
,
ElementD
,
TileShape
>
;
// MMA type
using
ElementAccumulator
=
float
;
// Epilogue types
using
ElementBias
=
cutlass
::
half_t
;
using
ElementCompute
=
float
;
using
ElementAux
=
ElementD
;
using
LayoutAux
=
LayoutD
;
using
ElementAmax
=
float
;
using
EVTCompute
=
typename
Epilogue
::
EVTCompute
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
cutlass
::
arch
::
Sm120
,
cutlass
::
arch
::
OpClassTensorOp
,
TileShape
,
ClusterShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementAccumulator
,
ElementCompute
,
ElementC
,
LayoutC
,
AlignmentC
,
ElementD
,
LayoutD
,
AlignmentD
,
EpilogueSchedule
,
EVTCompute
>::
CollectiveOp
;
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
cutlass
::
arch
::
Sm120
,
cutlass
::
arch
::
OpClassTensorOp
,
ElementAB
,
LayoutA
,
AlignmentA
,
ElementAB
,
LayoutB
,
AlignmentB
,
ElementAccumulator
,
TileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
KernelSchedule
>::
CollectiveOp
;
using
GemmKernel
=
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
Shape
<
int
,
int
,
int
,
int
>
,
CollectiveMainloop
,
CollectiveEpilogue
,
void
>
;
};
}
// namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
View file @
99324e25
...
...
@@ -36,6 +36,12 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_sm120_fp8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_blockwise_sm100_fp8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
...
...
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
View file @
99324e25
...
...
@@ -15,11 +15,11 @@ using c3x::cutlass_gemm_caller;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm100_fp8_config_default
{
// M in (
128
, inf)
// M in (
256
, inf)
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
collective
::
KernelScheduleAuto
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
using
TileShape
=
Shape
<
_256
,
_128
,
_
64
>
;
using
TileShape
=
Shape
<
_256
,
_128
,
_
128
>
;
using
ClusterShape
=
Shape
<
_2
,
_2
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm_sm100
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
...
...
@@ -28,13 +28,13 @@ struct sm100_fp8_config_default {
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm100_fp8_config_M
128
{
// M in (64,
128
]
struct
sm100_fp8_config_M
256
{
// M in (64,
256
]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
collective
::
KernelScheduleAuto
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
using
TileShape
=
Shape
<
_128
,
_128
,
_
64
>
;
using
ClusterShape
=
Shape
<
_2
,
_
2
,
_1
>
;
using
TileShape
=
Shape
<
_128
,
_128
,
_
128
>
;
using
ClusterShape
=
Shape
<
_2
,
_
1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm_sm100
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
...
...
@@ -43,12 +43,26 @@ struct sm100_fp8_config_M128 {
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm100_fp8_config_M64
{
// M in [1, 64]
// M in (16, 64]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
collective
::
KernelScheduleAuto
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
using
TileShape
=
Shape
<
_64
,
_64
,
_128
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm_sm100
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm100_fp8_config_M16
{
// M in [1, 16]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
collective
::
KernelScheduleAuto
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
using
TileShape
=
Shape
<
_64
,
_64
,
_
256
>
;
using
ClusterShape
=
Shape
<
_1
,
_
8
,
_1
>
;
using
TileShape
=
Shape
<
_64
,
_64
,
_
128
>
;
using
ClusterShape
=
Shape
<
_1
,
_
4
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm_sm100
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
...
...
@@ -68,25 +82,31 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
using
Cutlass3xGemmDefault
=
typename
sm100_fp8_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM16
=
typename
sm100_fp8_config_M16
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM64
=
typename
sm100_fp8_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM
128
=
typename
sm100_fp8_config_M
128
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM
256
=
typename
sm100_fp8_config_M
256
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
6
4
),
next_pow_2
(
m
));
// next power of 2
std
::
max
(
static_cast
<
uint32_t
>
(
1
6
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
64
)
{
// m in [1, 64]
if
(
mp2
<=
16
)
{
// m in [1, 16]
return
cutlass_gemm_caller
<
Cutlass3xGemmM16
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
64
)
{
// m in (16, 64]
return
cutlass_gemm_caller
<
Cutlass3xGemmM64
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// m in (64,
128
]
return
cutlass_gemm_caller
<
Cutlass3xGemmM
128
>
(
}
else
if
(
mp2
<=
256
)
{
// m in (64,
256
]
return
cutlass_gemm_caller
<
Cutlass3xGemmM
256
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// m in (
128
, inf)
// m in (
256
, inf)
return
cutlass_gemm_caller
<
Cutlass3xGemmDefault
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
...
...
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu
0 → 100644
View file @
99324e25
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm120_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace
vllm
{
void
cutlass_scaled_mm_sm120_fp8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"currently bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_mm_sm120_fp8_epilogue
<
c3x
::
ScaledEpilogueBias
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
return
cutlass_scaled_mm_sm120_fp8_epilogue
<
c3x
::
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8_dispatch.cuh
0 → 100644
View file @
99324e25
#pragma once
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
/**
* This file defines Gemm kernel configurations for SM120 (fp8) based on the
* Gemm shape.
*/
namespace
vllm
{
using
c3x
::
cutlass_gemm_caller
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm120_fp8_config_default
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
collective
::
KernelScheduleAuto
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
// Only work with Shape<_1, _1, _1>
using
Cutlass3xGemm
=
cutlass_3x_gemm_sm120
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
inline
void
cutlass_gemm_sm120_fp8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
args
)
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
Cutlass3xGemmDefault
=
typename
sm120_fp8_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
return
cutlass_gemm_caller
<
Cutlass3xGemmDefault
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
template
<
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_scaled_mm_sm120_fp8_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_args
)
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm120_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm120_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
}
// namespace vllm
\ No newline at end of file
csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu
0 → 100644
View file @
99324e25
#include "core/registration.h"
#include <torch/all.h>
#include <cutlass/arch/arch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include <cassert>
using
namespace
cute
;
template
<
typename
ElementAB
,
typename
ElementC
,
typename
ElementAccumulator
,
typename
LayoutSFA
,
typename
LayoutSFB
,
typename
ScaleConfig
>
__global__
void
get_ggemm_starts
(
int32_t
*
expert_offsets
,
ElementAB
**
a_offsets
,
ElementAB
**
b_offsets
,
ElementC
**
out_offsets
,
ElementAccumulator
**
a_scale_offsets
,
ElementAccumulator
**
b_scale_offsets
,
ElementAB
*
a_base_as_int
,
ElementAB
*
b_base_as_int
,
ElementC
*
out_base_as_int
,
ElementAccumulator
*
a_scale_base_as_int
,
ElementAccumulator
*
b_scale_base_as_int
,
LayoutSFA
*
layout_sfa_base_as_int
,
LayoutSFB
*
layout_sfb_base_as_int
,
int
*
problem_sizes
)
{
int
expert_id
=
threadIdx
.
x
;
if
(
expert_id
>=
gridDim
.
x
*
blockDim
.
x
)
{
return
;
}
int
m
=
problem_sizes
[
expert_id
*
3
];
int
n
=
problem_sizes
[
expert_id
*
3
+
1
];
int
k
=
problem_sizes
[
expert_id
*
3
+
2
];
int32_t
expert_offset
=
expert_offsets
[
expert_id
];
int
a_stride
=
expert_offset
*
k
;
int
b_stride
=
expert_id
*
k
*
n
;
int
a_scale_stride
=
expert_offset
*
k
/
128
;
int
b_scale_stride
=
expert_id
*
k
*
n
/
128
/
128
;
a_offsets
[
expert_id
]
=
a_base_as_int
+
a_stride
;
b_offsets
[
expert_id
]
=
b_base_as_int
+
b_stride
;
out_offsets
[
expert_id
]
=
out_base_as_int
+
expert_offset
*
n
;
a_scale_offsets
[
expert_id
]
=
a_scale_base_as_int
+
a_scale_stride
;
b_scale_offsets
[
expert_id
]
=
b_scale_base_as_int
+
b_scale_stride
;
LayoutSFA
*
layout_sfa_ptr
=
layout_sfa_base_as_int
+
expert_id
;
LayoutSFB
*
layout_sfb_ptr
=
layout_sfb_base_as_int
+
expert_id
;
*
layout_sfa_ptr
=
ScaleConfig
::
tile_atom_to_shape_SFA
(
cute
::
make_shape
(
m
,
n
,
k
,
1
));
*
layout_sfb_ptr
=
ScaleConfig
::
tile_atom_to_shape_SFB
(
cute
::
make_shape
(
m
,
n
,
k
,
1
));
}
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, \
ScaleConfig) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
get_ggemm_starts<cutlass::float_e4m3_t, C_TYPE, float, LayoutSFA, \
LayoutSFB, ScaleConfig><<<1, num_experts, 0, stream>>>( \
static_cast<int32_t*>(expert_offsets.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()), \
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
static_cast<float**>(a_scales_ptrs.data_ptr()), \
static_cast<float**>(b_scales_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()), \
static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()), \
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
static_cast<float*>(a_scales.data_ptr()), \
static_cast<float*>(b_scales.data_ptr()), \
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), \
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()), \
static_cast<int*>(problem_sizes.data_ptr())); \
}
template
<
typename
LayoutSFA
,
typename
LayoutSFB
,
typename
ScaleConfig
>
void
run_get_ggemm_starts
(
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
&
a_ptrs
,
torch
::
Tensor
&
b_ptrs
,
torch
::
Tensor
&
out_ptrs
,
torch
::
Tensor
&
a_scales_ptrs
,
torch
::
Tensor
&
b_scales_ptrs
,
torch
::
Tensor
const
&
a_tensors
,
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
out_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
layout_sfa
,
torch
::
Tensor
const
&
layout_sfb
,
torch
::
Tensor
const
&
problem_sizes
)
{
TORCH_CHECK
(
a_tensors
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b_tensors
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
out_tensors
.
size
(
1
)
%
128
==
0
or
out_tensors
.
size
(
0
)
%
128
==
0
);
TORCH_CHECK
(
a_tensors
.
size
(
1
)
%
128
==
0
or
a_tensors
.
size
(
0
)
%
128
==
0
);
int
num_experts
=
(
int
)
expert_offsets
.
size
(
0
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a_tensors
.
device
().
index
());
if
(
false
)
{
}
__CALL_GET_STARTS_KERNEL
(
torch
::
kBFloat16
,
cutlass
::
bfloat16_t
,
LayoutSFA
,
LayoutSFB
,
ScaleConfig
)
__CALL_GET_STARTS_KERNEL
(
torch
::
kFloat16
,
cutlass
::
half_t
,
LayoutSFA
,
LayoutSFB
,
ScaleConfig
)
else
{
TORCH_CHECK
(
false
,
"Unsupported output tensor type"
);
}
}
template
<
typename
OutType
,
typename
ScheduleConfig
,
typename
LayoutD
>
void
run_blockwise_scaled_group_mm
(
torch
::
Tensor
&
out_ptrs
,
const
torch
::
Tensor
&
a_ptrs
,
const
torch
::
Tensor
&
b_ptrs
,
const
torch
::
Tensor
&
a_scales_ptrs
,
const
torch
::
Tensor
&
b_scales_ptrs
,
const
torch
::
Tensor
&
stride_a
,
const
torch
::
Tensor
&
stride_b
,
const
torch
::
Tensor
&
stride_c
,
const
torch
::
Tensor
&
layout_sfa
,
const
torch
::
Tensor
&
layout_sfb
,
const
torch
::
Tensor
&
problem_sizes
,
const
torch
::
Tensor
&
expert_offsets
)
{
using
ProblemShape
=
cutlass
::
gemm
::
GroupProblemShape
<
Shape
<
int
,
int
,
int
>>
;
// Types
using
ElementA
=
cutlass
::
float_e4m3_t
;
using
ElementB
=
cutlass
::
float_e4m3_t
;
using
ElementC
=
OutType
;
using
ElementD
=
ElementC
;
using
ElementAccumulator
=
float
;
using
LayoutA
=
cutlass
::
layout
::
RowMajor
;
using
LayoutB
=
cutlass
::
layout
::
ColumnMajor
;
using
LayoutC
=
LayoutD
;
// Alignments
static
constexpr
int
AlignmentA
=
128
/
cutlass
::
sizeof_bits
<
ElementA
>::
value
;
static
constexpr
int
AlignmentB
=
128
/
cutlass
::
sizeof_bits
<
ElementB
>::
value
;
static
constexpr
int
AlignmentC
=
128
/
cutlass
::
sizeof_bits
<
ElementC
>::
value
;
using
ArchTag
=
cutlass
::
arch
::
Sm100
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
typename
ScheduleConfig
::
MmaTileShape
,
typename
ScheduleConfig
::
ClusterShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementAccumulator
,
ElementAccumulator
,
void
,
LayoutC
*
,
AlignmentC
,
ElementD
,
LayoutC
*
,
AlignmentC
,
typename
ScheduleConfig
::
EpilogueSchedule
>::
CollectiveOp
;
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
ElementA
,
cute
::
tuple
<
LayoutA
*
,
typename
ScheduleConfig
::
LayoutSFA
*>
,
AlignmentA
,
ElementB
,
cute
::
tuple
<
LayoutB
*
,
typename
ScheduleConfig
::
LayoutSFB
*>
,
AlignmentB
,
ElementAccumulator
,
typename
ScheduleConfig
::
MmaTileShape
,
typename
ScheduleConfig
::
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
typename
ScheduleConfig
::
KernelSchedule
>::
CollectiveOp
;
using
GemmKernel
=
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
ProblemShape
,
CollectiveMainloop
,
CollectiveEpilogue
,
void
>
;
using
Gemm
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
GemmKernel
>
;
using
StrideA
=
typename
Gemm
::
GemmKernel
::
InternalStrideA
;
using
StrideB
=
typename
Gemm
::
GemmKernel
::
InternalStrideB
;
using
StrideC
=
typename
Gemm
::
GemmKernel
::
InternalStrideC
;
using
StrideD
=
typename
Gemm
::
GemmKernel
::
InternalStrideD
;
using
UnderlyingProblemShape
=
ProblemShape
::
UnderlyingProblemShape
;
int
num_experts
=
(
int
)
expert_offsets
.
size
(
0
);
Gemm
gemm_op
;
// Mainloop Arguments
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
static_cast
<
const
ElementA
**>
(
a_ptrs
.
data_ptr
()),
static_cast
<
StrideA
*>
(
stride_a
.
data_ptr
()),
static_cast
<
const
ElementB
**>
(
b_ptrs
.
data_ptr
()),
static_cast
<
StrideB
*>
(
stride_b
.
data_ptr
()),
static_cast
<
const
ElementAccumulator
**>
(
a_scales_ptrs
.
data_ptr
()),
reinterpret_cast
<
typename
ScheduleConfig
::
LayoutSFA
*>
(
layout_sfa
.
data_ptr
()),
static_cast
<
const
ElementAccumulator
**>
(
b_scales_ptrs
.
data_ptr
()),
reinterpret_cast
<
typename
ScheduleConfig
::
LayoutSFB
*>
(
layout_sfb
.
data_ptr
())};
cutlass
::
KernelHardwareInfo
hw_info
;
hw_info
.
device_id
=
a_ptrs
.
get_device
();
hw_info
.
sm_count
=
cutlass
::
KernelHardwareInfo
::
query_device_multiprocessor_count
(
hw_info
.
device_id
);
// Epilogue Arguments
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
{},
// epilogue.thread
nullptr
,
static_cast
<
StrideC
*>
(
stride_c
.
data_ptr
()),
static_cast
<
ElementD
**>
(
out_ptrs
.
data_ptr
()),
static_cast
<
StrideC
*>
(
stride_c
.
data_ptr
())};
UnderlyingProblemShape
*
problem_sizes_as_shapes
=
static_cast
<
UnderlyingProblemShape
*>
(
problem_sizes
.
data_ptr
());
// Gemm Arguments
typename
GemmKernel
::
Arguments
args
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGrouped
,
{
num_experts
,
problem_sizes_as_shapes
,
nullptr
},
mainloop_args
,
epilogue_args
,
hw_info
};
at
::
cuda
::
CUDAGuard
device_guard
{(
char
)
a_ptrs
.
device
().
index
()};
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a_ptrs
.
get_device
());
auto
can_implement_status
=
gemm_op
.
can_implement
(
args
);
TORCH_CHECK
(
can_implement_status
==
cutlass
::
Status
::
kSuccess
,
"Failed to implement GEMM"
);
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
a_ptrs
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
auto
status
=
gemm_op
.
initialize
(
args
,
workspace
.
data_ptr
(),
stream
);
TORCH_CHECK
(
status
==
cutlass
::
Status
::
kSuccess
,
"Failed to initialize GEMM"
);
status
=
gemm_op
.
run
(
stream
);
TORCH_CHECK
(
status
==
cutlass
::
Status
::
kSuccess
,
"Failed to run GEMM"
);
}
template
<
typename
OutType
>
void
blockwise_scaled_group_mm_dispatch_shape
(
torch
::
Tensor
&
output
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
scales_a
,
const
torch
::
Tensor
&
scales_b
,
const
torch
::
Tensor
&
problem_sizes
,
const
torch
::
Tensor
&
expert_offsets
)
{
struct
MmaConfig
{
using
ElementA
=
cutlass
::
float_e4m3_t
;
using
KernelSchedule
=
cutlass
::
gemm
::
KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
PtrArrayTmaWarpSpecialized1Sm
;
using
ScaleConfig
=
cutlass
::
detail
::
Sm100BlockwiseScaleConfig
<
1
,
128
,
128
,
cute
::
UMMA
::
Major
::
K
,
cute
::
UMMA
::
Major
::
K
>
;
using
LayoutSFA
=
decltype
(
ScaleConfig
::
deduce_layoutSFA
());
using
LayoutSFB
=
decltype
(
ScaleConfig
::
deduce_layoutSFB
());
using
LayoutC
=
cutlass
::
layout
::
RowMajor
;
using
MmaTileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
};
int
num_experts
=
(
int
)
expert_offsets
.
size
(
0
);
auto
a_ptrs
=
torch
::
empty
(
{
num_experts
},
torch
::
TensorOptions
().
dtype
(
torch
::
kInt64
).
device
(
a
.
device
()));
auto
b_ptrs
=
torch
::
empty
(
{
num_experts
},
torch
::
TensorOptions
().
dtype
(
torch
::
kInt64
).
device
(
a
.
device
()));
auto
out_ptrs
=
torch
::
empty
(
{
num_experts
},
torch
::
TensorOptions
().
dtype
(
torch
::
kInt64
).
device
(
a
.
device
()));
auto
a_scales_ptrs
=
torch
::
empty
(
{
num_experts
},
torch
::
TensorOptions
().
dtype
(
torch
::
kInt64
).
device
(
a
.
device
()));
auto
b_scales_ptrs
=
torch
::
empty
(
{
num_experts
},
torch
::
TensorOptions
().
dtype
(
torch
::
kInt64
).
device
(
a
.
device
()));
auto
layout_sfa
=
torch
::
empty
(
{
num_experts
,
5
},
torch
::
TensorOptions
().
dtype
(
torch
::
kInt32
).
device
(
a
.
device
()));
auto
layout_sfb
=
torch
::
empty
(
{
num_experts
,
5
},
torch
::
TensorOptions
().
dtype
(
torch
::
kInt32
).
device
(
a
.
device
()));
auto
stride_a
=
torch
::
full
(
{
num_experts
},
a
.
size
(
1
),
torch
::
TensorOptions
().
dtype
(
torch
::
kInt64
).
device
(
a
.
device
()));
auto
stride_b
=
torch
::
full
(
{
num_experts
},
a
.
size
(
1
),
torch
::
TensorOptions
().
dtype
(
torch
::
kInt64
).
device
(
a
.
device
()));
auto
stride_c
=
torch
::
full
(
{
num_experts
},
output
.
size
(
1
),
torch
::
TensorOptions
().
dtype
(
torch
::
kInt64
).
device
(
a
.
device
()));
torch
::
TensorOptions
options_int
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt64
).
device
(
a
.
device
());
run_get_ggemm_starts
<
typename
MmaConfig
::
LayoutSFA
,
typename
MmaConfig
::
LayoutSFB
,
typename
MmaConfig
::
ScaleConfig
>
(
expert_offsets
,
a_ptrs
,
b_ptrs
,
out_ptrs
,
a_scales_ptrs
,
b_scales_ptrs
,
a
,
b
,
output
,
scales_a
,
scales_b
,
layout_sfa
,
layout_sfb
,
problem_sizes
);
run_blockwise_scaled_group_mm
<
OutType
,
MmaConfig
,
typename
MmaConfig
::
LayoutC
>
(
out_ptrs
,
a_ptrs
,
b_ptrs
,
a_scales_ptrs
,
b_scales_ptrs
,
stride_a
,
stride_b
,
stride_c
,
layout_sfa
,
layout_sfb
,
problem_sizes
,
expert_offsets
);
}
void
cutlass_blockwise_scaled_grouped_mm
(
torch
::
Tensor
&
output
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
scales_a
,
const
torch
::
Tensor
&
scales_b
,
const
torch
::
Tensor
&
problem_sizes
,
const
torch
::
Tensor
&
expert_offsets
)
{
TORCH_CHECK
(
problem_sizes
.
dim
()
==
2
,
"problem_sizes must be 2D tensor"
);
TORCH_CHECK
(
problem_sizes
.
size
(
1
)
==
3
,
"problem_sizes must have shape (num_experts, 3)"
);
TORCH_CHECK
(
problem_sizes
.
size
(
0
)
==
expert_offsets
.
size
(
0
),
"Number of experts in problem_sizes must match expert_offsets"
);
TORCH_CHECK
(
problem_sizes
.
dtype
()
==
torch
::
kInt32
,
"problem_sizes must be int32"
);
TORCH_CHECK
(
a
.
scalar_type
()
==
torch
::
kFloat8_e4m3fn
,
"a must be kFloat8_e4m3fn"
);
TORCH_CHECK
(
b
.
scalar_type
()
==
torch
::
kFloat8_e4m3fn
,
"b must be kFloat8_e4m3fn"
);
TORCH_CHECK
(
output
.
scalar_type
()
==
torch
::
kBFloat16
||
output
.
scalar_type
()
==
torch
::
kHalf
,
"output must be bfloat16 or half"
);
TORCH_CHECK
(
scales_a
.
scalar_type
()
==
torch
::
kFloat32
,
"scales_a must be float32"
);
TORCH_CHECK
(
scales_b
.
scalar_type
()
==
torch
::
kFloat32
,
"scales_b must be float32"
);
TORCH_CHECK
(
expert_offsets
.
scalar_type
()
==
torch
::
kInt32
,
"expert_offsets must be int32"
);
TORCH_CHECK
(
output
.
dim
()
==
2
,
"output must be 2D tensor"
);
TORCH_CHECK
(
a
.
dim
()
==
2
,
"a must be 2D tensor"
);
TORCH_CHECK
(
b
.
dim
()
==
3
,
"b must be 3D tensor"
);
TORCH_CHECK
(
scales_a
.
dim
()
==
2
,
"scales_a must be 2D tensor"
);
TORCH_CHECK
(
scales_b
.
dim
()
==
3
,
"scales_b must be 3D tensor"
);
TORCH_CHECK
(
problem_sizes
.
dim
()
==
2
,
"problem_sizes must be 2D tensor"
);
TORCH_CHECK
(
problem_sizes
.
size
(
1
)
==
3
,
"problem_sizes must have shape (num_experts, 3)"
);
TORCH_CHECK
(
problem_sizes
.
size
(
0
)
==
expert_offsets
.
size
(
0
),
"Number of experts in problem_sizes must match expert_offsets"
);
TORCH_CHECK
(
problem_sizes
.
dtype
()
==
torch
::
kInt32
,
"problem_sizes must be int32"
);
TORCH_CHECK
(
expert_offsets
.
dim
()
==
1
,
"expert_offsets must be 1D tensor"
);
#if defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100
if
(
output
.
scalar_type
()
==
torch
::
kBFloat16
)
{
blockwise_scaled_group_mm_dispatch_shape
<
cutlass
::
bfloat16_t
>
(
output
,
a
,
b
,
scales_a
,
scales_b
,
problem_sizes
,
expert_offsets
);
}
else
if
(
output
.
scalar_type
()
==
torch
::
kFloat16
)
{
blockwise_scaled_group_mm_dispatch_shape
<
cutlass
::
half_t
>
(
output
,
a
,
b
,
scales_a
,
scales_b
,
problem_sizes
,
expert_offsets
);
}
else
{
TORCH_CHECK
(
false
,
"Unsupported output tensor type"
);
}
#endif
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"cutlass_blockwise_scaled_grouped_mm"
,
&
cutlass_blockwise_scaled_grouped_mm
);
}
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu
0 → 100644
View file @
99324e25
#include <cudaTypedefs.h>
#include "c3x/scaled_mm_kernels.hpp"
#include "cuda_utils.h"
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm120 (Blackwell Geforce).
*/
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
void
cutlass_scaled_mm_sm120
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
int
M
=
a
.
size
(
0
),
N
=
b
.
size
(
1
),
K
=
a
.
size
(
1
);
TORCH_CHECK
(
(
a_scales
.
numel
()
==
1
||
a_scales
.
numel
()
==
a
.
size
(
0
))
&&
(
b_scales
.
numel
()
==
1
||
b_scales
.
numel
()
==
b
.
size
(
1
)),
"Currently, block scaled fp8 gemm is not implemented for Blackwell"
);
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
,
"Currently, only fp8 gemm is implemented for Blackwell"
);
vllm
::
cutlass_scaled_mm_sm120_fp8
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
#endif
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
99324e25
...
...
@@ -41,6 +41,14 @@ void cutlass_moe_mm_sm90(
#endif
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
void
cutlass_scaled_mm_sm120
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#endif
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
void
cutlass_scaled_mm_sm100
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
...
...
@@ -168,8 +176,15 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
at
::
cuda
::
OptionalCUDAGuard
const
device_guard
(
device_of
(
a
));
int32_t
version_num
=
get_sm_version_num
();
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
if
(
version_num
>=
120
)
{
cutlass_scaled_mm_sm120
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
return
;
}
#endif
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
if
(
version_num
>=
100
)
{
if
(
version_num
>=
100
&&
version_num
<
120
)
{
cutlass_scaled_mm_sm100
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
return
;
}
...
...
@@ -241,7 +256,7 @@ void get_cutlass_moe_mm_data(
// mm to run it for.
int32_t
version_num
=
get_sm_version_num
();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_
SCALED_MM
_SM100 && ENABLE_
SCALED_MM
_SM
9
0)
(defined ENABLE_
CUTLASS_MOE
_SM100 && ENABLE_
CUTLASS_MOE
_SM
10
0)
get_cutlass_moe_mm_data_caller
(
topk_ids
,
expert_offsets
,
problem_sizes1
,
problem_sizes2
,
input_permutation
,
output_permutation
,
num_experts
,
n
,
k
,
...
...
@@ -252,7 +267,7 @@ void get_cutlass_moe_mm_data(
false
,
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
"CUDA device capability: "
,
version_num
,
". Required capability: 90"
);
version_num
,
". Required capability: 90
or 100
"
);
}
void
get_cutlass_pplx_moe_mm_data
(
torch
::
Tensor
&
expert_offsets
,
...
...
@@ -265,7 +280,8 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t
version_num
=
get_sm_version_num
();
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
get_cutlass_pplx_moe_mm_data_caller
(
expert_offsets
,
problem_sizes1
,
problem_sizes2
,
expert_num_tokens
,
num_local_experts
,
padded_m
,
n
,
k
);
...
...
@@ -275,7 +291,7 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
false
,
"No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
"for CUDA device capability: "
,
version_num
,
". Required capability: 90"
);
version_num
,
". Required capability: 90
or 100
"
);
}
void
cutlass_scaled_mm_azp
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
...
...
csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
View file @
99324e25
...
...
@@ -335,8 +335,10 @@ void run_fp4_blockwise_scaled_group_mm(
TORCH_CHECK
(
status
==
cutlass
::
Status
::
kSuccess
,
"Failed to run GEMM"
);
}
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
constexpr
auto
FLOAT4_E2M1X2
=
at
::
ScalarType
::
Byte
;
constexpr
auto
SF_DTYPE
=
at
::
ScalarType
::
Float8_e4m3fn
;
#endif
#define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
...
...
csrc/quantization/fp4/nvfp4_experts_quant.cu
View file @
99324e25
...
...
@@ -231,7 +231,7 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
}
// Use UE4M3 by default.
template
<
class
Type
,
bool
UE8M0_SF
=
false
>
template
<
class
Type
,
bool
UE8M0_SF
=
false
,
bool
SMALL_NUM_EXPERTS
=
false
>
__global__
void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__
(
512
,
4
)
cvt_fp16_to_fp4
(
...
...
@@ -240,7 +240,7 @@ cvt_fp16_to_fp4(
#endif
int32_t
numRows
,
int32_t
numCols
,
Type
const
*
in
,
float
const
*
SFScale
,
uint32_t
*
out
,
uint32_t
*
SFout
,
uint32_t
*
input_offset_by_experts
,
uint32_t
*
output_scale_offset_by_experts
,
int
n_experts
)
{
uint32_t
*
output_scale_offset_by_experts
,
int
n_experts
,
bool
low_latency
)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using
PackedVec
=
PackedVec
<
Type
>
;
static
constexpr
int
CVT_FP4_NUM_THREADS_PER_SF
=
...
...
@@ -248,49 +248,182 @@ cvt_fp16_to_fp4(
static_assert
(
sizeof
(
PackedVec
)
==
sizeof
(
Type
)
*
CVT_FP4_ELTS_PER_THREAD
,
"Vec size is not matched."
);
// Input tensor row/col loops.
for
(
int
rowIdx
=
blockIdx
.
x
;
rowIdx
<
numRows
;
rowIdx
+=
gridDim
.
x
)
{
for
(
int
colIdx
=
threadIdx
.
x
;
colIdx
<
numCols
/
CVT_FP4_ELTS_PER_THREAD
;
colIdx
+=
blockDim
.
x
)
{
int64_t
inOffset
=
rowIdx
*
(
numCols
/
CVT_FP4_ELTS_PER_THREAD
)
+
colIdx
;
PackedVec
in_vec
=
reinterpret_cast
<
PackedVec
const
*>
(
in
)[
inOffset
];
// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t
outOffset
=
inOffset
;
auto
&
out_pos
=
out
[
outOffset
];
// Find index within the experts.
int
rowIdx_in_expert
=
0
;
int
expert_idx
=
0
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
colsPerRow
=
numCols
/
CVT_FP4_ELTS_PER_THREAD
;
// Each global thread processes one element
for
(
int
globalIdx
=
tid
;
globalIdx
<
numRows
*
colsPerRow
;
globalIdx
+=
gridDim
.
x
*
blockDim
.
x
)
{
// Calculate which row and column this global thread should process
int
rowIdx
=
globalIdx
/
colsPerRow
;
int
colIdx
=
globalIdx
%
colsPerRow
;
int64_t
inOffset
=
rowIdx
*
colsPerRow
+
colIdx
;
PackedVec
in_vec
=
reinterpret_cast
<
PackedVec
const
*>
(
in
)[
inOffset
];
// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t
outOffset
=
inOffset
;
auto
&
out_pos
=
out
[
outOffset
];
// Find index within the experts using different strategies based on expert
// count
int
rowIdx_in_expert
=
0
;
int
expert_idx
=
0
;
if
constexpr
(
SMALL_NUM_EXPERTS
)
{
for
(
int
i
=
0
;
i
<
n_experts
;
i
++
)
{
if
(
rowIdx
>=
input_offset_by_experts
[
i
]
&&
rowIdx
<
input_offset_by_experts
[
i
+
1
])
{
rowIdx_in_expert
=
rowIdx
-
input_offset_by_experts
[
i
];
uint32_t
current_offset
=
__ldca
(
&
input_offset_by_experts
[
i
]);
uint32_t
next_offset
=
__ldca
(
&
input_offset_by_experts
[
i
+
1
]);
if
(
rowIdx
>=
current_offset
&&
rowIdx
<
next_offset
)
{
rowIdx_in_expert
=
rowIdx
-
current_offset
;
expert_idx
=
i
;
break
;
}
}
}
else
{
// Load input offsets into registers first, then do the computation.
// Local array size set to 17 because of register limit.
uint32_t
local_offsets
[
17
];
for
(
int
chunk_start
=
0
;
chunk_start
<
n_experts
;
chunk_start
+=
16
)
{
*
reinterpret_cast
<
int4
*>
(
local_offsets
)
=
__ldca
(
reinterpret_cast
<
const
int4
*>
(
&
input_offset_by_experts
[
chunk_start
]));
*
reinterpret_cast
<
int4
*>
(
local_offsets
+
4
)
=
__ldca
(
reinterpret_cast
<
const
int4
*>
(
&
input_offset_by_experts
[
chunk_start
+
4
]));
*
reinterpret_cast
<
int4
*>
(
local_offsets
+
8
)
=
__ldca
(
reinterpret_cast
<
const
int4
*>
(
&
input_offset_by_experts
[
chunk_start
+
8
]));
*
reinterpret_cast
<
int4
*>
(
local_offsets
+
12
)
=
__ldca
(
reinterpret_cast
<
const
int4
*>
(
&
input_offset_by_experts
[
chunk_start
+
12
]));
local_offsets
[
16
]
=
__ldca
(
&
input_offset_by_experts
[
chunk_start
+
16
]);
// Check against the 16 loaded offsets
#pragma unroll
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
if
(
rowIdx
>=
local_offsets
[
i
]
&&
rowIdx
<
local_offsets
[
i
+
1
])
{
rowIdx_in_expert
=
rowIdx
-
local_offsets
[
i
];
expert_idx
=
chunk_start
+
i
;
break
;
}
}
}
}
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
float
const
SFScaleVal
=
SFScale
==
nullptr
?
1.0
f
:
SFScale
[
expert_idx
];
int
factor
=
CVT_FP4_SF_VEC_SIZE
*
4
;
// The actual output_scales dim is computed from the padded numCols.
int32_t
numCols_padded
=
(
numCols
+
factor
-
1
)
/
factor
*
factor
;
int
numCols_SFout
=
numCols_padded
/
CVT_FP4_SF_VEC_SIZE
/
4
;
uint32_t
*
SFout_in_expert
=
SFout
+
output_scale_offset_by_experts
[
expert_idx
]
*
numCols_SFout
;
auto
sf_out
=
cvt_quant_to_fp4_get_sf_out_offset
<
uint32_t
,
CVT_FP4_NUM_THREADS_PER_SF
>
(
rowIdx_in_expert
,
colIdx
,
numCols
,
SFout_in_expert
);
out_pos
=
cvt_warp_fp16_to_fp4
<
Type
,
UE8M0_SF
>
(
in_vec
,
SFScaleVal
,
sf_out
);
}
#endif
}
// Kernel for LARGE_M_TOPK = true (large m_topk optimized version)
template
<
class
Type
,
bool
UE8M0_SF
=
false
,
bool
SMALL_NUM_EXPERTS
=
false
>
__global__
void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__
(
1024
,
4
)
cvt_fp16_to_fp4
(
#else
cvt_fp16_to_fp4
(
#endif
int32_t
numRows
,
int32_t
numCols
,
Type
const
*
in
,
float
const
*
SFScale
,
uint32_t
*
out
,
uint32_t
*
SFout
,
uint32_t
*
input_offset_by_experts
,
uint32_t
*
output_scale_offset_by_experts
,
int
n_experts
)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using
PackedVec
=
PackedVec
<
Type
>
;
static
constexpr
int
CVT_FP4_NUM_THREADS_PER_SF
=
(
CVT_FP4_SF_VEC_SIZE
/
CVT_FP4_ELTS_PER_THREAD
);
static_assert
(
sizeof
(
PackedVec
)
==
sizeof
(
Type
)
*
CVT_FP4_ELTS_PER_THREAD
,
"Vec size is not matched."
);
extern
__shared__
uint32_t
shared_input_offsets
[];
// Load input offsets into shared memory.
// If n_experts is larger than 4, use vectorized int4 to save instructions.
// If n_experts is smaller than 4, read directly.
if
constexpr
(
SMALL_NUM_EXPERTS
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
n_experts
+
1
;
i
+=
blockDim
.
x
)
{
shared_input_offsets
[
i
]
=
input_offset_by_experts
[
i
];
}
}
else
{
for
(
int
i
=
threadIdx
.
x
*
4
;
i
<
n_experts
;
i
+=
blockDim
.
x
*
4
)
{
*
reinterpret_cast
<
int4
*>
(
&
shared_input_offsets
[
i
])
=
*
reinterpret_cast
<
const
int4
*>
(
&
input_offset_by_experts
[
i
]);
}
if
(
threadIdx
.
x
==
0
)
{
shared_input_offsets
[
n_experts
]
=
input_offset_by_experts
[
n_experts
];
}
}
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
float
const
SFScaleVal
=
SFScale
==
nullptr
?
1.0
f
:
SFScale
[
expert_idx
];
int
factor
=
CVT_FP4_SF_VEC_SIZE
*
4
;
// The actual output_scales dim is computed from the padded numCols.
int32_t
numCols_padded
=
(
numCols
+
factor
-
1
)
/
factor
*
factor
;
int
numCols_SFout
=
numCols_padded
/
CVT_FP4_SF_VEC_SIZE
/
4
;
uint32_t
*
SFout_in_expert
=
SFout
+
output_scale_offset_by_experts
[
expert_idx
]
*
numCols_SFout
;
auto
sf_out
=
cvt_quant_to_fp4_get_sf_out_offset
<
uint32_t
,
CVT_FP4_NUM_THREADS_PER_SF
>
(
rowIdx_in_expert
,
colIdx
,
numCols
,
SFout_in_expert
);
out_pos
=
cvt_warp_fp16_to_fp4
<
Type
,
UE8M0_SF
>
(
in_vec
,
SFScaleVal
,
sf_out
);
__syncthreads
();
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
colsPerRow
=
numCols
/
CVT_FP4_ELTS_PER_THREAD
;
// Each global thread processes one element
for
(
int
globalIdx
=
tid
;
globalIdx
<
numRows
*
colsPerRow
;
globalIdx
+=
gridDim
.
x
*
blockDim
.
x
)
{
// Calculate which row and column this global thread should process
int
rowIdx
=
globalIdx
/
colsPerRow
;
int
colIdx
=
globalIdx
%
colsPerRow
;
int64_t
inOffset
=
rowIdx
*
colsPerRow
+
colIdx
;
PackedVec
in_vec
=
reinterpret_cast
<
PackedVec
const
*>
(
in
)[
inOffset
];
int64_t
outOffset
=
inOffset
;
auto
&
out_pos
=
out
[
outOffset
];
// Find expert using binary search for better performance with large m_topk
int
rowIdx_in_expert
=
0
;
int
expert_idx
=
0
;
// Binary search through experts using shared memory
int
left
=
0
,
right
=
n_experts
-
1
;
while
(
left
<=
right
)
{
int
mid
=
(
left
+
right
)
/
2
;
// Get offsets: shared_input_offsets[i] corresponds to
// input_offset_by_experts[i]
uint32_t
mid_offset
=
shared_input_offsets
[
mid
];
uint32_t
next_offset
=
shared_input_offsets
[
mid
+
1
];
if
(
rowIdx
>=
mid_offset
&&
rowIdx
<
next_offset
)
{
rowIdx_in_expert
=
rowIdx
-
mid_offset
;
expert_idx
=
mid
;
break
;
}
else
if
(
rowIdx
<
mid_offset
)
{
right
=
mid
-
1
;
}
else
{
left
=
mid
+
1
;
}
}
float
const
SFScaleVal
=
SFScale
==
nullptr
?
1.0
f
:
SFScale
[
expert_idx
];
int
factor
=
CVT_FP4_SF_VEC_SIZE
*
4
;
int32_t
numCols_padded
=
(
numCols
+
factor
-
1
)
/
factor
*
factor
;
int
numCols_SFout
=
numCols_padded
/
CVT_FP4_SF_VEC_SIZE
/
4
;
uint32_t
*
SFout_in_expert
=
SFout
+
output_scale_offset_by_experts
[
expert_idx
]
*
numCols_SFout
;
auto
sf_out
=
cvt_quant_to_fp4_get_sf_out_offset
<
uint32_t
,
CVT_FP4_NUM_THREADS_PER_SF
>
(
rowIdx_in_expert
,
colIdx
,
numCols
,
SFout_in_expert
);
out_pos
=
cvt_warp_fp16_to_fp4
<
Type
,
UE8M0_SF
>
(
in_vec
,
SFScaleVal
,
sf_out
);
}
#endif
}
...
...
@@ -309,18 +442,63 @@ void quant_impl(void* output, void* output_scale, void* input,
// Grid, Block size.
// Each thread converts 8 values.
dim3
block
(
std
::
min
(
int
(
k
/
ELTS_PER_THREAD
),
512
));
int
const
workSizePerRow
=
k
/
ELTS_PER_THREAD
;
int
const
totalWorkSize
=
m_topk
*
workSizePerRow
;
dim3
block
(
std
::
min
(
workSizePerRow
,
512
));
// Get number of blocks per SM (assume we can fully utilize the SM).
int
const
numBlocksPerSM
=
2048
/
block
.
x
;
dim3
grid
(
std
::
min
(
int
(
m_topk
),
multiProcessorCount
*
numBlocksPerSM
));
cvt_fp16_to_fp4
<
T
,
false
><<<
grid
,
block
,
0
,
stream
>>>
(
m_topk
,
k
,
reinterpret_cast
<
T
*>
(
input
),
reinterpret_cast
<
float
*>
(
input_global_scale
),
reinterpret_cast
<
uint32_t
*>
(
output
),
reinterpret_cast
<
uint32_t
*>
(
output_scale
),
reinterpret_cast
<
uint32_t
*>
(
input_offset_by_experts
),
reinterpret_cast
<
uint32_t
*>
(
output_scale_offset_by_experts
),
n_experts
);
dim3
grid
(
std
::
min
(
static_cast
<
int
>
((
totalWorkSize
+
block
.
x
-
1
)
/
block
.
x
),
multiProcessorCount
*
numBlocksPerSM
));
while
(
grid
.
x
<=
multiProcessorCount
&&
block
.
x
>
64
)
{
grid
.
x
*=
2
;
block
.
x
=
(
block
.
x
+
1
)
/
2
;
}
int
const
blockRepeat
=
(
totalWorkSize
+
block
.
x
*
grid
.
x
-
1
)
/
(
block
.
x
*
grid
.
x
);
if
(
blockRepeat
>
1
)
{
size_t
shared_mem_size
=
(
n_experts
+
1
)
*
sizeof
(
uint32_t
);
if
(
n_experts
>=
4
)
{
cvt_fp16_to_fp4
<
T
,
false
,
false
>
<<<
grid
,
block
,
shared_mem_size
,
stream
>>>
(
m_topk
,
k
,
reinterpret_cast
<
T
*>
(
input
),
reinterpret_cast
<
float
*>
(
input_global_scale
),
reinterpret_cast
<
uint32_t
*>
(
output
),
reinterpret_cast
<
uint32_t
*>
(
output_scale
),
reinterpret_cast
<
uint32_t
*>
(
input_offset_by_experts
),
reinterpret_cast
<
uint32_t
*>
(
output_scale_offset_by_experts
),
n_experts
);
}
else
{
cvt_fp16_to_fp4
<
T
,
false
,
true
><<<
grid
,
block
,
shared_mem_size
,
stream
>>>
(
m_topk
,
k
,
reinterpret_cast
<
T
*>
(
input
),
reinterpret_cast
<
float
*>
(
input_global_scale
),
reinterpret_cast
<
uint32_t
*>
(
output
),
reinterpret_cast
<
uint32_t
*>
(
output_scale
),
reinterpret_cast
<
uint32_t
*>
(
input_offset_by_experts
),
reinterpret_cast
<
uint32_t
*>
(
output_scale_offset_by_experts
),
n_experts
);
}
}
else
{
if
(
n_experts
>=
16
)
{
cvt_fp16_to_fp4
<
T
,
false
,
false
><<<
grid
,
block
,
0
,
stream
>>>
(
m_topk
,
k
,
reinterpret_cast
<
T
*>
(
input
),
reinterpret_cast
<
float
*>
(
input_global_scale
),
reinterpret_cast
<
uint32_t
*>
(
output
),
reinterpret_cast
<
uint32_t
*>
(
output_scale
),
reinterpret_cast
<
uint32_t
*>
(
input_offset_by_experts
),
reinterpret_cast
<
uint32_t
*>
(
output_scale_offset_by_experts
),
n_experts
,
/* bool low_latency */
true
);
}
else
{
cvt_fp16_to_fp4
<
T
,
false
,
true
><<<
grid
,
block
,
0
,
stream
>>>
(
m_topk
,
k
,
reinterpret_cast
<
T
*>
(
input
),
reinterpret_cast
<
float
*>
(
input_global_scale
),
reinterpret_cast
<
uint32_t
*>
(
output
),
reinterpret_cast
<
uint32_t
*>
(
output_scale
),
reinterpret_cast
<
uint32_t
*>
(
input_offset_by_experts
),
reinterpret_cast
<
uint32_t
*>
(
output_scale_offset_by_experts
),
n_experts
,
/* bool low_latency */
true
);
}
}
}
/*Quantization entry for fp4 experts quantization*/
...
...
@@ -383,7 +561,7 @@ void scaled_fp4_experts_quant_sm100a(
TORCH_CHECK
(
output_scale
.
size
(
1
)
*
4
==
padded_k
);
auto
in_dtype
=
input
.
dtype
();
at
::
cuda
::
CUDAGuard
device_guard
{(
char
)
input
.
get_device
()}
;
const
at
::
cuda
::
Optional
CUDAGuard
device_guard
(
device_of
(
input
))
;
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
input
.
get_device
());
if
(
in_dtype
==
at
::
ScalarType
::
Half
)
{
...
...
@@ -401,4 +579,4 @@ void scaled_fp4_experts_quant_sm100a(
}
else
{
TORCH_CHECK
(
false
,
"Expected input data type to be half or bfloat16"
);
}
}
\ No newline at end of file
}
csrc/quantization/fp4/nvfp4_quant_kernels.cu
View file @
99324e25
...
...
@@ -347,7 +347,7 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output,
auto
input_sf_ptr
=
static_cast
<
float
const
*>
(
input_sf
.
data_ptr
());
auto
sf_out
=
static_cast
<
int32_t
*>
(
output_sf
.
data_ptr
());
auto
output_ptr
=
static_cast
<
int64_t
*>
(
output
.
data_ptr
());
at
::
cuda
::
CUDAGuard
device_guard
{(
char
)
input
.
get_device
()}
;
const
at
::
cuda
::
Optional
CUDAGuard
device_guard
(
device_of
(
input
))
;
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
input
.
get_device
());
// We don't support e8m0 scales at this moment.
...
...
csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
View file @
99324e25
...
...
@@ -267,7 +267,7 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
B_sf
.
sizes
()[
1
],
")"
);
auto
out_dtype
=
D
.
dtype
();
at
::
cuda
::
CUDAGuard
device_guard
{(
char
)
A
.
get_device
()}
;
const
at
::
cuda
::
Optional
CUDAGuard
device_guard
(
device_of
(
A
))
;
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
A
.
get_device
());
if
(
out_dtype
==
at
::
ScalarType
::
Half
)
{
...
...
csrc/quantization/fp8/amd/quant_utils.cuh
View file @
99324e25
...
...
@@ -448,8 +448,6 @@ scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
template
<
>
__inline__
__device__
uint32_t
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
[[
maybe_unused
]]
__half2_raw
h2r
=
__hip_cvt_fp8x2_to_halfraw2
(
a
,
fp8_type
::
__default_interpret
);
union
{
__half2_raw
h2r
;
uint32_t
ui32
;
...
...
csrc/quantization/gguf/gguf_kernel.cu
View file @
99324e25
...
...
@@ -92,111 +92,112 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
torch
::
Tensor
X
,
// input
int64_t
type
,
int64_t
row
)
{
int
col
=
X
.
sizes
()[
1
];
int
vecs
=
X
.
sizes
()[
0
];
const
int
padded
=
(
col
+
512
-
1
)
/
512
*
512
;
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
X
));
auto
options
=
torch
::
TensorOptions
().
dtype
(
X
.
dtype
()).
device
(
W
.
device
());
at
::
Tensor
Y
=
torch
::
empty
({
1
,
row
},
options
);
at
::
Tensor
Y
=
torch
::
empty
({
vecs
,
row
},
options
);
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt32
).
device
(
W
.
device
());
at
::
Tensor
quant_X
=
torch
::
empty
({
1
,
padded
/
32
*
9
},
options
);
at
::
Tensor
quant_X
=
torch
::
empty
({
vecs
,
padded
/
32
*
9
},
options
);
VLLM_DISPATCH_FLOATING_TYPES
(
X
.
scalar_type
(),
"ggml_mul_mat_vec_a8"
,
[
&
]
{
quantize_row_q8_1_cuda
<
scalar_t
>
(
(
scalar_t
*
)
X
.
data_ptr
(),
(
void
*
)
quant_X
.
data_ptr
(),
col
,
1
,
stream
);
quantize_row_q8_1_cuda
<
scalar_t
>
(
(
scalar_t
*
)
X
.
data_ptr
(),
(
void
*
)
quant_X
.
data_ptr
(),
col
,
vecs
,
stream
);
switch
(
type
)
{
case
2
:
mul_mat_vec_q4_0_q8_1_cuda
<
scalar_t
>
(
(
void
*
)
W
.
data_ptr
(),
(
void
*
)
quant_X
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
stream
);
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
vecs
,
stream
);
break
;
case
3
:
mul_mat_vec_q4_1_q8_1_cuda
<
scalar_t
>
(
(
void
*
)
W
.
data_ptr
(),
(
void
*
)
quant_X
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
stream
);
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
vecs
,
stream
);
break
;
case
6
:
mul_mat_vec_q5_0_q8_1_cuda
<
scalar_t
>
(
(
void
*
)
W
.
data_ptr
(),
(
void
*
)
quant_X
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
stream
);
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
vecs
,
stream
);
break
;
case
7
:
mul_mat_vec_q5_1_q8_1_cuda
<
scalar_t
>
(
(
void
*
)
W
.
data_ptr
(),
(
void
*
)
quant_X
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
stream
);
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
vecs
,
stream
);
break
;
case
8
:
mul_mat_vec_q8_0_q8_1_cuda
<
scalar_t
>
(
(
void
*
)
W
.
data_ptr
(),
(
void
*
)
quant_X
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
stream
);
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
vecs
,
stream
);
break
;
case
10
:
mul_mat_vec_q2_K_q8_1_cuda
<
scalar_t
>
(
(
void
*
)
W
.
data_ptr
(),
(
void
*
)
quant_X
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
stream
);
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
vecs
,
stream
);
break
;
case
11
:
mul_mat_vec_q3_K_q8_1_cuda
<
scalar_t
>
(
(
void
*
)
W
.
data_ptr
(),
(
void
*
)
quant_X
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
stream
);
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
vecs
,
stream
);
break
;
case
12
:
mul_mat_vec_q4_K_q8_1_cuda
<
scalar_t
>
(
(
void
*
)
W
.
data_ptr
(),
(
void
*
)
quant_X
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
stream
);
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
vecs
,
stream
);
break
;
case
13
:
mul_mat_vec_q5_K_q8_1_cuda
<
scalar_t
>
(
(
void
*
)
W
.
data_ptr
(),
(
void
*
)
quant_X
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
stream
);
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
vecs
,
stream
);
break
;
case
14
:
mul_mat_vec_q6_K_q8_1_cuda
<
scalar_t
>
(
(
void
*
)
W
.
data_ptr
(),
(
void
*
)
quant_X
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
stream
);
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
vecs
,
stream
);
break
;
case
16
:
mul_mat_vec_iq2_xxs_q8_1_cuda
<
scalar_t
>
(
(
void
*
)
W
.
data_ptr
(),
(
void
*
)
quant_X
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
stream
);
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
vecs
,
stream
);
break
;
case
17
:
mul_mat_vec_iq2_xs_q8_1_cuda
<
scalar_t
>
(
(
void
*
)
W
.
data_ptr
(),
(
void
*
)
quant_X
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
stream
);
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
vecs
,
stream
);
break
;
case
18
:
mul_mat_vec_iq3_xxs_q8_1_cuda
<
scalar_t
>
(
(
void
*
)
W
.
data_ptr
(),
(
void
*
)
quant_X
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
stream
);
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
vecs
,
stream
);
break
;
case
19
:
mul_mat_vec_iq1_s_q8_1_cuda
<
scalar_t
>
(
(
void
*
)
W
.
data_ptr
(),
(
void
*
)
quant_X
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
stream
);
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
vecs
,
stream
);
break
;
case
20
:
mul_mat_vec_iq4_nl_q8_1_cuda
<
scalar_t
>
(
(
void
*
)
W
.
data_ptr
(),
(
void
*
)
quant_X
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
stream
);
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
vecs
,
stream
);
break
;
case
21
:
mul_mat_vec_iq3_s_q8_1_cuda
<
scalar_t
>
(
(
void
*
)
W
.
data_ptr
(),
(
void
*
)
quant_X
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
stream
);
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
vecs
,
stream
);
break
;
case
22
:
mul_mat_vec_iq2_s_q8_1_cuda
<
scalar_t
>
(
(
void
*
)
W
.
data_ptr
(),
(
void
*
)
quant_X
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
stream
);
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
vecs
,
stream
);
break
;
case
23
:
mul_mat_vec_iq4_xs_q8_1_cuda
<
scalar_t
>
(
(
void
*
)
W
.
data_ptr
(),
(
void
*
)
quant_X
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
stream
);
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
vecs
,
stream
);
break
;
case
29
:
mul_mat_vec_iq1_m_q8_1_cuda
<
scalar_t
>
(
(
void
*
)
W
.
data_ptr
(),
(
void
*
)
quant_X
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
stream
);
(
scalar_t
*
)
Y
.
data_ptr
(),
col
,
row
,
vecs
,
stream
);
break
;
}
});
...
...
csrc/quantization/gguf/mmvq.cuh
View file @
99324e25
// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu
template
<
typename
scalar_t
,
int
qk
,
int
qi
,
typename
block_q_t
,
int
vdr
,
vec_dot_q_cuda_t
vec_dot_q_cuda
>
static
__global__
void
mul_mat_vec_q
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
scalar_t
*
__restrict__
dst
,
const
int
ncols
,
const
int
nrows
)
{
static
__global__
void
mul_mat_vec_q
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
scalar_t
*
__restrict__
dst
,
const
int
ncols
,
const
int
nrows
,
const
int
nvecs
)
{
const
auto
row
=
blockIdx
.
x
*
blockDim
.
y
+
threadIdx
.
y
;
const
auto
vec
=
blockIdx
.
y
;
if
(
row
>=
nrows
)
{
if
(
row
>=
nrows
||
vec
>=
nvecs
)
{
return
;
}
const
int
blocks_per_row
=
ncols
/
qk
;
const
int
blocks_per_warp
=
vdr
*
WARP_SIZE
/
qi
;
const
int
nrows_y
=
(
ncols
+
512
-
1
)
/
512
*
512
;
// partial sum for each thread
// partial sum for each thread
float
tmp
=
0.0
f
;
const
block_q_t
*
x
=
(
const
block_q_t
*
)
vx
;
...
...
@@ -19,7 +22,7 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
for
(
auto
i
=
threadIdx
.
x
/
(
qi
/
vdr
);
i
<
blocks_per_row
;
i
+=
blocks_per_warp
)
{
const
int
ibx
=
row
*
blocks_per_row
+
i
;
// x block index
const
int
iby
=
i
*
(
qk
/
QK8_1
);
// y block index that aligns with ibx
const
int
iby
=
vec
*
(
nrows_y
/
QK8_1
)
+
i
*
(
qk
/
QK8_1
);
// y block index that aligns with ibx
const
int
iqs
=
vdr
*
(
threadIdx
.
x
%
(
qi
/
vdr
));
// x block quant index when casting the quants to int
...
...
@@ -33,177 +36,177 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
}
if
(
threadIdx
.
x
==
0
)
{
dst
[
row
]
=
tmp
;
dst
[
vec
*
nrows
+
row
]
=
tmp
;
}
}
template
<
typename
scalar_t
>
static
void
mul_mat_vec_q4_0_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
static
void
mul_mat_vec_q4_0_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
const
int
nvecs
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_nums
(
block_num_y
,
nvecs
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
scalar_t
,
QK4_0
,
QI4_0
,
block_q4_0
,
VDR_Q4_0_Q8_1_MMVQ
,
vec_dot_q4_0_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
,
nvecs
);
}
template
<
typename
scalar_t
>
static
void
mul_mat_vec_q4_1_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
static
void
mul_mat_vec_q4_1_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
const
int
nvecs
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_nums
(
block_num_y
,
nvecs
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
scalar_t
,
QK4_0
,
QI4_1
,
block_q4_1
,
VDR_Q4_1_Q8_1_MMVQ
,
vec_dot_q4_1_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
,
nvecs
);
}
template
<
typename
scalar_t
>
static
void
mul_mat_vec_q5_0_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
static
void
mul_mat_vec_q5_0_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
const
int
nvecs
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_nums
(
block_num_y
,
nvecs
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
scalar_t
,
QK5_0
,
QI5_0
,
block_q5_0
,
VDR_Q5_0_Q8_1_MMVQ
,
vec_dot_q5_0_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
,
nvecs
);
}
template
<
typename
scalar_t
>
static
void
mul_mat_vec_q5_1_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
static
void
mul_mat_vec_q5_1_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
const
int
nvecs
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_nums
(
block_num_y
,
nvecs
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
scalar_t
,
QK5_1
,
QI5_1
,
block_q5_1
,
VDR_Q5_1_Q8_1_MMVQ
,
vec_dot_q5_1_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
,
nvecs
);
}
template
<
typename
scalar_t
>
static
void
mul_mat_vec_q8_0_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
static
void
mul_mat_vec_q8_0_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
const
int
nvecs
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_nums
(
block_num_y
,
nvecs
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
scalar_t
,
QK8_0
,
QI8_0
,
block_q8_0
,
VDR_Q8_0_Q8_1_MMVQ
,
vec_dot_q8_0_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
,
nvecs
);
}
template
<
typename
scalar_t
>
static
void
mul_mat_vec_q2_K_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
static
void
mul_mat_vec_q2_K_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
const
int
nvecs
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_nums
(
block_num_y
,
nvecs
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
scalar_t
,
QK_K
,
QI2_K
,
block_q2_K
,
VDR_Q2_K_Q8_1_MMVQ
,
vec_dot_q2_K_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
,
nvecs
);
}
template
<
typename
scalar_t
>
static
void
mul_mat_vec_q3_K_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
static
void
mul_mat_vec_q3_K_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
const
int
nvecs
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_nums
(
block_num_y
,
nvecs
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
scalar_t
,
QK_K
,
QI3_K
,
block_q3_K
,
VDR_Q3_K_Q8_1_MMVQ
,
vec_dot_q3_K_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
,
nvecs
);
}
template
<
typename
scalar_t
>
static
void
mul_mat_vec_q4_K_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
static
void
mul_mat_vec_q4_K_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
const
int
nvecs
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_nums
(
block_num_y
,
nvecs
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
scalar_t
,
QK_K
,
QI4_K
,
block_q4_K
,
VDR_Q4_K_Q8_1_MMVQ
,
vec_dot_q4_K_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
,
nvecs
);
}
template
<
typename
scalar_t
>
static
void
mul_mat_vec_q5_K_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
static
void
mul_mat_vec_q5_K_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
const
int
nvecs
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_nums
(
block_num_y
,
nvecs
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
scalar_t
,
QK_K
,
QI5_K
,
block_q5_K
,
VDR_Q5_K_Q8_1_MMVQ
,
vec_dot_q5_K_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
,
nvecs
);
}
template
<
typename
scalar_t
>
static
void
mul_mat_vec_q6_K_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
static
void
mul_mat_vec_q6_K_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
const
int
nvecs
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_nums
(
block_num_y
,
nvecs
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
scalar_t
,
QK_K
,
QI6_K
,
block_q6_K
,
VDR_Q6_K_Q8_1_MMVQ
,
vec_dot_q6_K_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
,
nvecs
);
}
template
<
typename
scalar_t
>
static
void
mul_mat_vec_iq2_xxs_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
static
void
mul_mat_vec_iq2_xxs_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
const
int
nvecs
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_nums
(
block_num_y
,
nvecs
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
scalar_t
,
QK_K
,
QI2_XXS
,
block_iq2_xxs
,
1
,
vec_dot_iq2_xxs_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
,
nvecs
);
}
template
<
typename
scalar_t
>
static
void
mul_mat_vec_iq2_xs_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
static
void
mul_mat_vec_iq2_xs_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
const
int
nvecs
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_nums
(
block_num_y
,
nvecs
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
scalar_t
,
QK_K
,
QI2_XS
,
block_iq2_xs
,
1
,
vec_dot_iq2_xs_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
,
nvecs
);
}
template
<
typename
scalar_t
>
static
void
mul_mat_vec_iq2_s_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
static
void
mul_mat_vec_iq2_s_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
const
int
nvecs
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_nums
(
block_num_y
,
nvecs
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
scalar_t
,
QK_K
,
QI2_S
,
block_iq2_s
,
1
,
vec_dot_iq2_s_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
,
nvecs
);
}
template
<
typename
scalar_t
>
static
void
mul_mat_vec_iq3_xxs_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
static
void
mul_mat_vec_iq3_xxs_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
const
int
nvecs
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_nums
(
block_num_y
,
nvecs
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
scalar_t
,
QK_K
,
QI3_XXS
,
block_iq3_xxs
,
1
,
vec_dot_iq3_xxs_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
,
nvecs
);
}
template
<
typename
scalar_t
>
static
void
mul_mat_vec_iq1_s_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
static
void
mul_mat_vec_iq1_s_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
const
int
nvecs
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_nums
(
block_num_y
,
nvecs
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
scalar_t
,
QK_K
,
QI1_S
,
block_iq1_s
,
1
,
vec_dot_iq1_s_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
,
nvecs
);
}
template
<
typename
scalar_t
>
static
void
mul_mat_vec_iq1_m_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
static
void
mul_mat_vec_iq1_m_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
const
int
nvecs
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_nums
(
block_num_y
,
nvecs
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
scalar_t
,
QK_K
,
QI1_M
,
block_iq1_m
,
1
,
vec_dot_iq1_m_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
,
nvecs
);
}
template
<
typename
scalar_t
>
static
void
mul_mat_vec_iq4_nl_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
static
void
mul_mat_vec_iq4_nl_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
const
int
nvecs
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_nums
(
block_num_y
,
nvecs
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
scalar_t
,
QK4_NL
,
QI4_NL
,
block_iq4_nl
,
VDR_Q4_0_Q8_1_MMVQ
,
vec_dot_iq4_nl_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
,
nvecs
);
}
template
<
typename
scalar_t
>
static
void
mul_mat_vec_iq4_xs_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
static
void
mul_mat_vec_iq4_xs_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
const
int
nvecs
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_nums
(
block_num_y
,
nvecs
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
scalar_t
,
QK_K
,
QI4_XS
,
block_iq4_xs
,
1
,
vec_dot_iq4_xs_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
,
nvecs
);
}
template
<
typename
scalar_t
>
static
void
mul_mat_vec_iq3_s_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
static
void
mul_mat_vec_iq3_s_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
const
int
nvecs
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_nums
(
block_num_y
,
nvecs
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
scalar_t
,
QK_K
,
QI3_XS
,
block_iq3_s
,
1
,
vec_dot_iq3_s_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
,
nvecs
);
}
csrc/quantization/gptq/q_gemm.cu
View file @
99324e25
...
...
@@ -210,8 +210,6 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel(
auto
offset_m
=
blockIdx
.
y
*
m_count
;
auto
offset_k
=
blockIdx
.
z
*
BLOCK_KN_SIZE
;
[[
maybe_unused
]]
int
end_n
=
min
(
offset_n
+
BLOCK_KN_SIZE
*
4
,
size_n
);
[[
maybe_unused
]]
int
end_m
=
min
(
offset_m
+
m_count
,
size_m
);
int
end_k
=
min
(
offset_k
+
BLOCK_KN_SIZE
,
size_k
);
int
n
=
offset_n
+
t
*
4
;
...
...
@@ -348,8 +346,6 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel(
auto
offset_m
=
blockIdx
.
y
*
m_count
;
auto
offset_k
=
blockIdx
.
z
*
BLOCK_KN_SIZE
;
[[
maybe_unused
]]
int
end_n
=
min
(
offset_n
+
BLOCK_KN_SIZE
*
4
,
size_n
);
[[
maybe_unused
]]
int
end_m
=
min
(
offset_m
+
m_count
,
size_m
);
int
end_k
=
min
(
offset_k
+
BLOCK_KN_SIZE
,
size_k
);
int
n
=
offset_n
+
t
*
4
;
...
...
@@ -469,8 +465,6 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel(
auto
offset_m
=
blockIdx
.
y
*
m_count
;
auto
offset_k
=
blockIdx
.
z
*
BLOCK_KN_SIZE
;
[[
maybe_unused
]]
int
end_n
=
min
(
offset_n
+
BLOCK_KN_SIZE
*
4
,
size_n
);
[[
maybe_unused
]]
int
end_m
=
min
(
offset_m
+
m_count
,
size_m
);
int
end_k
=
min
(
offset_k
+
BLOCK_KN_SIZE
,
size_k
);
int
n
=
offset_n
+
t
*
4
;
...
...
@@ -597,8 +591,6 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel(
auto
offset_m
=
blockIdx
.
y
*
m_count
;
auto
offset_k
=
blockIdx
.
z
*
BLOCK_KN_SIZE
;
[[
maybe_unused
]]
int
end_n
=
min
(
offset_n
+
BLOCK_KN_SIZE
*
4
,
size_n
);
[[
maybe_unused
]]
int
end_m
=
min
(
offset_m
+
m_count
,
size_m
);
int
end_k
=
min
(
offset_k
+
BLOCK_KN_SIZE
,
size_k
);
int
n
=
offset_n
+
t
*
4
;
...
...
csrc/quantization/gptq_marlin/marlin_template.h
View file @
99324e25
...
...
@@ -1113,8 +1113,6 @@ __global__ void Marlin(
if
constexpr
(
has_zp
&&
!
is_zp_float
)
{
if
(
is_new_zp
)
{
if
constexpr
(
group_blocks
==
-
1
)
is_first_matmul_in_slice
=
false
;
FragB
frag_zp_0
;
FragB
frag_zp_1
;
int
zp_quant_0
,
zp_quant_1
;
if
constexpr
(
w_type
.
size_bits
()
==
4
)
{
...
...
csrc/quantization/machete/machete_mainloop.cuh
View file @
99324e25
...
...
@@ -38,7 +38,6 @@
#include "cute/atom/mma_atom.hpp"
#include "cute/atom/copy_traits_sm90_tma.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/tensor_predicate.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/transform/collective/sm90_wgmma_transpose.hpp"
...
...
@@ -1003,7 +1002,7 @@ struct MacheteCollectiveMma {
static
constexpr
int
A_CPY_VEC
=
decltype
(
max_common_vector
(
tCsA
,
tCrA_load
)){};
static
constexpr
int
COVERSION_WIDTH
=
static
constexpr
int
CO
N
VERSION_WIDTH
=
std
::
min
(
A_CPY_VEC
,
int
(
size
<
0
>
(
tCrA_mma
)));
auto
load_A_to_registers
=
[
&
](
int
read_stage
)
{
...
...
@@ -1026,8 +1025,8 @@ struct MacheteCollectiveMma {
// PIPELINED MAIN LOOP
//
auto
convert_A
=
[
&
,
a_vec
=
Int
<
COVERSION_WIDTH
>
{}](
int
k_block
,
int
read_stage
)
{
auto
convert_A
=
[
&
,
a_vec
=
Int
<
CO
N
VERSION_WIDTH
>
{}](
int
k_block
,
int
read_stage
)
{
load_extra_info_to_registers
(
partitioned_extra_info
,
copy_partitions_extra_info
,
k_block
,
read_stage
);
...
...
csrc/quantization/vectorization_utils.cuh
0 → 100644
View file @
99324e25
#pragma once
#include "vectorization.cuh"
namespace
vllm
{
template
<
int
VEC_SIZE
,
typename
InT
,
typename
OutT
,
typename
ScaOp
>
struct
DefaultVecOp
{
ScaOp
scalar_op
;
__device__
__forceinline__
void
operator
()(
vec_n_t
<
OutT
,
VEC_SIZE
>&
dst
,
const
vec_n_t
<
InT
,
VEC_SIZE
>&
src
)
const
{
#pragma unroll
for
(
int
i
=
0
;
i
<
VEC_SIZE
;
++
i
)
{
scalar_op
(
dst
.
val
[
i
],
src
.
val
[
i
]);
}
}
};
template
<
int
VEC_SIZE
,
typename
InT
,
typename
OutT
,
typename
VecOp
,
typename
ScaOp
>
__device__
inline
void
vectorize_with_alignment
(
const
InT
*
in
,
OutT
*
out
,
int
len
,
int
tid
,
int
stride
,
VecOp
&&
vec_op
,
// vec_n_t<InT,16> -> vec_n_t<OutT,16>
ScaOp
&&
scalar_op
)
{
// InT -> OutT
static_assert
(
VEC_SIZE
>
0
&&
(
VEC_SIZE
&
(
VEC_SIZE
-
1
))
==
0
,
"VEC_SIZE must be a positive power-of-two"
);
constexpr
int
WIDTH
=
VEC_SIZE
*
sizeof
(
InT
);
// eg: 64 B
uintptr_t
addr
=
reinterpret_cast
<
uintptr_t
>
(
in
);
// fast path when the whole region is already aligned
// Note: currently the output is guaranteed to be same as the input, so we
// don't check it here, comments here just for future reference.
bool
can_vec
=
((
addr
&
(
WIDTH
-
1
))
==
0
)
&&
((
len
&
(
VEC_SIZE
-
1
))
==
0
);
if
(
can_vec
)
{
int
num_vec
=
len
/
VEC_SIZE
;
using
vin_t
=
vec_n_t
<
InT
,
VEC_SIZE
>
;
using
vout_t
=
vec_n_t
<
OutT
,
VEC_SIZE
>
;
auto
*
v_in
=
reinterpret_cast
<
const
vin_t
*>
(
in
);
auto
*
v_out
=
reinterpret_cast
<
vout_t
*>
(
out
);
for
(
int
i
=
tid
;
i
<
num_vec
;
i
+=
stride
)
{
vout_t
tmp
;
vec_op
(
tmp
,
v_in
[
i
]);
v_out
[
i
]
=
tmp
;
}
return
;
}
int
misalignment_offset
=
addr
&
(
WIDTH
-
1
);
// addr % 64
int
alignment_bytes
=
WIDTH
-
misalignment_offset
;
// 64 - (addr % 64)
int
prefix_elems
=
alignment_bytes
&
(
WIDTH
-
1
);
// handle 64
prefix_elems
/=
sizeof
(
InT
);
prefix_elems
=
min
(
prefix_elems
,
len
);
// 0 ≤ prefix < 16
// 1. prefill the when it is unsafe to vectorize
for
(
int
i
=
tid
;
i
<
prefix_elems
;
i
+=
stride
)
{
scalar_op
(
out
[
i
],
in
[
i
]);
}
in
+=
prefix_elems
;
out
+=
prefix_elems
;
len
-=
prefix_elems
;
int
num_vec
=
len
/
VEC_SIZE
;
using
vin_t
=
vec_n_t
<
InT
,
VEC_SIZE
>
;
using
vout_t
=
vec_n_t
<
OutT
,
VEC_SIZE
>
;
auto
*
v_in
=
reinterpret_cast
<
const
vin_t
*>
(
in
);
auto
*
v_out
=
reinterpret_cast
<
vout_t
*>
(
out
);
// 2. vectorize the main part
for
(
int
i
=
tid
;
i
<
num_vec
;
i
+=
stride
)
{
vout_t
tmp
;
vec_op
(
tmp
,
v_in
[
i
]);
v_out
[
i
]
=
tmp
;
}
// 3. handle the tail
int
tail_start
=
num_vec
*
VEC_SIZE
;
for
(
int
i
=
tid
+
tail_start
;
i
<
len
;
i
+=
stride
)
{
scalar_op
(
out
[
i
],
in
[
i
]);
}
}
template
<
int
VEC_SIZE
,
typename
InT
,
typename
OutT
,
typename
ScaOp
>
__device__
__forceinline__
void
vectorize_with_alignment
(
const
InT
*
in
,
OutT
*
out
,
int
len
,
int
tid
,
int
stride
,
ScaOp
&&
scalar_op
)
{
using
Vec
=
DefaultVecOp
<
VEC_SIZE
,
InT
,
OutT
,
std
::
decay_t
<
ScaOp
>>
;
vectorize_with_alignment
<
VEC_SIZE
>
(
in
,
out
,
len
,
tid
,
stride
,
Vec
{
scalar_op
},
std
::
forward
<
ScaOp
>
(
scalar_op
));
}
template
<
int
VEC_SIZE
,
typename
InT
,
typename
ScaOp
>
struct
DefaultReadVecOp
{
ScaOp
scalar_op
;
__device__
__forceinline__
void
operator
()(
const
vec_n_t
<
InT
,
VEC_SIZE
>&
src
)
const
{
#pragma unroll
for
(
int
i
=
0
;
i
<
VEC_SIZE
;
++
i
)
{
scalar_op
(
src
.
val
[
i
]);
}
}
};
// read-only version: iterate over the input with alignment guarantees
template
<
int
VEC_SIZE
,
typename
InT
,
typename
VecOp
,
typename
ScaOp
>
__device__
inline
void
vectorize_read_with_alignment
(
const
InT
*
in
,
int
len
,
int
tid
,
int
stride
,
VecOp
&&
vec_op
,
ScaOp
&&
scalar_op
)
{
static_assert
(
VEC_SIZE
>
0
&&
(
VEC_SIZE
&
(
VEC_SIZE
-
1
))
==
0
,
"VEC_SIZE must be a positive power-of-two"
);
constexpr
int
WIDTH
=
VEC_SIZE
*
sizeof
(
InT
);
uintptr_t
addr
=
reinterpret_cast
<
uintptr_t
>
(
in
);
// fast path when the whole region is already aligned
bool
can_vec
=
((
addr
&
(
WIDTH
-
1
))
==
0
)
&&
((
len
&
(
VEC_SIZE
-
1
))
==
0
);
if
(
can_vec
)
{
int
num_vec
=
len
/
VEC_SIZE
;
using
vin_t
=
vec_n_t
<
InT
,
VEC_SIZE
>
;
auto
*
v_in
=
reinterpret_cast
<
const
vin_t
*>
(
in
);
for
(
int
i
=
tid
;
i
<
num_vec
;
i
+=
stride
)
{
vec_op
(
v_in
[
i
]);
}
return
;
}
int
misalignment_offset
=
addr
&
(
WIDTH
-
1
);
int
alignment_bytes
=
WIDTH
-
misalignment_offset
;
int
prefix_elems
=
alignment_bytes
&
(
WIDTH
-
1
);
prefix_elems
/=
sizeof
(
InT
);
prefix_elems
=
min
(
prefix_elems
,
len
);
// 1. handle the possibly unaligned prefix with scalar access.
for
(
int
i
=
tid
;
i
<
prefix_elems
;
i
+=
stride
)
{
scalar_op
(
in
[
i
]);
}
in
+=
prefix_elems
;
len
-=
prefix_elems
;
int
num_vec
=
len
/
VEC_SIZE
;
using
vin_t
=
vec_n_t
<
InT
,
VEC_SIZE
>
;
auto
*
v_in
=
reinterpret_cast
<
const
vin_t
*>
(
in
);
// 2. vectorized traversal of the main aligned region.
for
(
int
i
=
tid
;
i
<
num_vec
;
i
+=
stride
)
{
vec_op
(
v_in
[
i
]);
}
// 3. handle remaining tail elements.
int
tail_start
=
num_vec
*
VEC_SIZE
;
for
(
int
i
=
tid
+
tail_start
;
i
<
len
;
i
+=
stride
)
{
scalar_op
(
in
[
i
]);
}
}
// overload that requires only a scalar_op
template
<
int
VEC_SIZE
,
typename
InT
,
typename
ScaOp
>
__device__
__forceinline__
void
vectorize_read_with_alignment
(
const
InT
*
in
,
int
len
,
int
tid
,
int
stride
,
ScaOp
&&
scalar_op
)
{
using
Vec
=
DefaultReadVecOp
<
VEC_SIZE
,
InT
,
std
::
decay_t
<
ScaOp
>>
;
vectorize_read_with_alignment
<
VEC_SIZE
>
(
in
,
len
,
tid
,
stride
,
Vec
{
scalar_op
},
std
::
forward
<
ScaOp
>
(
scalar_op
));
}
}
// namespace vllm
Prev
1
2
3
4
5
6
7
8
9
…
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment