Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
afd0da21
Commit
afd0da21
authored
Feb 03, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.7.1' into v0.7.1-dev
parents
1a11f127
4f4d427a
Changes
587
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
497 additions
and
171 deletions
+497
-171
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu
...tization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu
+24
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
...utlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
+168
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
+33
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu
+24
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh
...tization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh
+25
-1
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
+24
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh
...ization/cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh
+24
-1
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
+9
-9
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+56
-60
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+29
-19
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+29
-27
csrc/quantization/machete/generate.py
csrc/quantization/machete/generate.py
+5
-5
csrc/quantization/machete/machete_collective_builder.cuh
csrc/quantization/machete/machete_collective_builder.cuh
+4
-6
csrc/quantization/machete/machete_mainloop.cuh
csrc/quantization/machete/machete_mainloop.cuh
+8
-7
csrc/quantization/machete/machete_mm_kernel.cuh
csrc/quantization/machete/machete_mm_kernel.cuh
+5
-5
csrc/quantization/machete/machete_mm_launcher.cuh
csrc/quantization/machete/machete_mm_launcher.cuh
+12
-12
csrc/quantization/machete/machete_prepack_launcher.cuh
csrc/quantization/machete/machete_prepack_launcher.cuh
+1
-1
csrc/quantization/machete/machete_prepacked_layout.cuh
csrc/quantization/machete/machete_prepacked_layout.cuh
+2
-3
csrc/quantization/machete/machete_pytorch.cu
csrc/quantization/machete/machete_pytorch.cu
+13
-13
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
+2
-2
No files found.
Too many changes to show.
To preserve performance only
587 of 587+
files are displayed.
Plain diff
Email patch
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu
0 → 100644
View file @
afd0da21
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace
vllm
{
void
cutlass_scaled_mm_blockwise_sm90_fp8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
cutlass_gemm_blockwise_sm90_fp8_dispatch
<
cutlass
::
bfloat16_t
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
cutlass_gemm_blockwise_sm90_fp8_dispatch
<
cutlass
::
half_t
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
// namespace vllm
\ No newline at end of file
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
0 → 100644
View file @
afd0da21
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_gemm_caller.cuh"
namespace
vllm
{
using
namespace
cute
;
template
<
typename
OutType
,
int
GroupSizeM_
,
int
GroupSizeN_
,
int
GroupSizeK_
,
int
TileSizeM_
=
128
,
class
ClusterShape
=
Shape
<
_1
,
_2
,
_1
>
>
struct
cutlass_3x_gemm_fp8_blockwise
{
using
GroupSizeM
=
Int
<
GroupSizeM_
>
;
using
GroupSizeN
=
Int
<
GroupSizeN_
>
;
using
GroupSizeK
=
Int
<
GroupSizeK_
>
;
using
TileSizeM
=
Int
<
TileSizeM_
>
;
static_assert
(
TileSizeM_
%
GroupSizeM_
==
0
,
"TileSizeM must be a multiple of GroupSizeM"
);
using
ElementAB
=
cutlass
::
float_e4m3_t
;
using
ElementA
=
ElementAB
;
using
LayoutA
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
AlignmentA
=
128
/
cutlass
::
sizeof_bits
<
ElementA
>::
value
;
using
ElementB
=
ElementAB
;
using
LayoutB
=
cutlass
::
layout
::
ColumnMajor
;
static
constexpr
int
AlignmentB
=
128
/
cutlass
::
sizeof_bits
<
ElementB
>::
value
;
using
ElementD
=
OutType
;
using
StrideD
=
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>
;
static
constexpr
int
AlignmentD
=
128
/
cutlass
::
sizeof_bits
<
ElementD
>::
value
;
using
ElementC
=
void
;
using
StrideC
=
StrideD
;
static
constexpr
int
AlignmentC
=
AlignmentD
;
using
ElementAccumulator
=
float
;
using
ElementBlockScale
=
float
;
using
ElementCompute
=
float
;
using
ArchTag
=
cutlass
::
arch
::
Sm90
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
using
TileShape
=
Shape
<
TileSizeM
,
GroupSizeN
,
GroupSizeK
>
;
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum
<
GroupSizeM_
>
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
TmaWarpSpecializedCooperative
;
using
EpilogueTileType
=
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
;
using
StoreEpilogueCompute
=
typename
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
>
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
TileShape
,
ClusterShape
,
EpilogueTileType
,
ElementAccumulator
,
ElementCompute
,
ElementC
,
StrideC
,
AlignmentC
,
ElementD
,
StrideD
,
AlignmentD
,
EpilogueSchedule
,
StoreEpilogueCompute
>::
CollectiveOp
;
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
ElementA
,
LayoutA
,
AlignmentA
,
ElementB
,
LayoutB
,
AlignmentB
,
ElementAccumulator
,
TileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
KernelSchedule
>::
CollectiveOp
;
using
KernelType
=
enable_sm90_or_later
<
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
Shape
<
int
,
int
,
int
,
int
>
,
CollectiveMainloop
,
CollectiveEpilogue
,
cutlass
::
gemm
::
PersistentScheduler
>>
;
struct
GemmKernel
:
public
KernelType
{};
using
StrideA
=
typename
GemmKernel
::
StrideA
;
using
StrideB
=
typename
GemmKernel
::
StrideB
;
};
template
<
typename
Gemm
>
void
cutlass_gemm_caller_blockwise
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
auto
prob_shape
=
c3x
::
get_problem_shape
(
a
,
b
);
int32_t
m
=
get
<
0
>
(
prob_shape
),
n
=
get
<
1
>
(
prob_shape
),
k
=
get
<
2
>
(
prob_shape
);
int64_t
lda
=
a
.
stride
(
0
);
int64_t
ldb
=
b
.
stride
(
1
);
int64_t
ldc
=
out
.
stride
(
0
);
using
StrideA
=
Stride
<
int64_t
,
Int
<
1
>
,
int64_t
>
;
using
StrideB
=
Stride
<
int64_t
,
Int
<
1
>
,
int64_t
>
;
using
StrideC
=
typename
Gemm
::
StrideC
;
StrideA
a_stride
{
lda
,
Int
<
1
>
{},
0
};
StrideB
b_stride
{
ldb
,
Int
<
1
>
{},
0
};
StrideC
c_stride
{
ldc
,
Int
<
1
>
{},
Int
<
0
>
{}};
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
*>
(
b
.
data_ptr
());
auto
a_scales_ptr
=
static_cast
<
float
*>
(
a_scales
.
data_ptr
());
auto
b_scales_ptr
=
static_cast
<
float
*>
(
b_scales
.
data_ptr
());
// Check is the t is contiguous and is 1D or 2D with one of the dimensions
// being 1 (i.e. a row or column vector)
auto
is_contiguous_vector
=
[](
const
torch
::
Tensor
&
t
)
{
auto
t_sizes
=
t
.
sizes
();
return
t
.
is_contiguous
()
&&
(
t
.
dim
()
==
1
||
(
t
.
dim
()
==
2
&&
*
std
::
min_element
(
t_sizes
.
begin
(),
t_sizes
.
end
())
==
1
));
};
// TODO(lucas): lets clean-up the kernel so that we pass in Strides so
// we don't have to deal with enforcing implicit layouts
TORCH_CHECK
(
a_scales
.
size
(
0
)
==
m
/
Gemm
::
GroupSizeM
::
value
);
TORCH_CHECK
(
a_scales
.
size
(
1
)
==
k
/
Gemm
::
GroupSizeK
::
value
);
TORCH_CHECK
(
a_scales
.
stride
(
0
)
==
1
||
is_contiguous_vector
(
a_scales
),
"a_scales must be M major"
);
TORCH_CHECK
(
b_scales
.
size
(
0
)
==
k
/
Gemm
::
GroupSizeK
::
value
);
TORCH_CHECK
(
b_scales
.
size
(
1
)
==
n
/
Gemm
::
GroupSizeN
::
value
);
TORCH_CHECK
(
b_scales
.
stride
(
0
)
==
1
||
is_contiguous_vector
(
b_scales
),
"b_scales must be K major"
);
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
a_ptr
,
a_stride
,
b_ptr
,
b_stride
,
a_scales_ptr
,
b_scales_ptr
};
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
{},
c_ptr
,
c_stride
,
c_ptr
,
c_stride
};
c3x
::
cutlass_gemm_caller
<
GemmKernel
>
(
a
.
device
(),
prob_shape
,
mainloop_args
,
epilogue_args
);
}
template
<
typename
OutType
>
void
cutlass_gemm_blockwise_sm90_fp8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
OutType
,
1
,
128
,
128
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
// namespace vllm
\ No newline at end of file
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
0 → 100644
View file @
afd0da21
#pragma once
#include <torch/all.h>
namespace
vllm
{
void
cutlass_scaled_mm_sm90_fp8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_sm90_int8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_azp_sm90_int8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_blockwise_sm90_fp8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
}
// namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu
0 → 100644
View file @
afd0da21
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace
vllm
{
void
cutlass_scaled_mm_sm90_fp8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"currently bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_mm_sm90_fp8_epilogue
<
c3x
::
ScaledEpilogueBias
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
return
cutlass_scaled_mm_sm90_fp8_epilogue
<
c3x
::
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_
c3x_
sm90_fp8_dispatch.cuh
→
csrc/quantization/cutlass_w8a8/
c3x/
scaled_mm_sm90_fp8_dispatch.cuh
View file @
afd0da21
#pragma once
#pragma once
#include "scaled_mm_c3x.cuh"
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
/**
/**
* This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm
* This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm
...
@@ -9,6 +10,8 @@
...
@@ -9,6 +10,8 @@
namespace
vllm
{
namespace
vllm
{
using
c3x
::
cutlass_gemm_caller
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_default
{
struct
sm90_fp8_config_default
{
...
@@ -93,4 +96,25 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
...
@@ -93,4 +96,25 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
}
}
}
}
template
<
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_scaled_mm_sm90_fp8_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_args
)
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
}
// namespace vllm
}
// namespace vllm
\ No newline at end of file
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
0 → 100644
View file @
afd0da21
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace
vllm
{
void
cutlass_scaled_mm_sm90_int8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"currently bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_mm_sm90_int8_epilogue
<
c3x
::
ScaledEpilogueBias
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
return
cutlass_scaled_mm_sm90_int8_epilogue
<
c3x
::
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_
c3x_
sm90_int8_dispatch.cuh
→
csrc/quantization/cutlass_w8a8/
c3x/
scaled_mm_sm90_int8_dispatch.cuh
View file @
afd0da21
#pragma once
#pragma once
#include "scaled_mm_c3x.cuh"
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
/**
/**
* This file defines Gemm kernel configurations for SM90 (int8) based on the
* This file defines Gemm kernel configurations for SM90 (int8) based on the
...
@@ -9,6 +10,8 @@
...
@@ -9,6 +10,8 @@
namespace
vllm
{
namespace
vllm
{
using
c3x
::
cutlass_gemm_caller
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_default
{
struct
sm90_int8_config_default
{
...
@@ -137,4 +140,24 @@ inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
...
@@ -137,4 +140,24 @@ inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
}
}
}
}
template
<
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_scaled_mm_sm90_int8_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_args
)
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm90_int8_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_int8_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
}
// namespace vllm
}
// namespace vllm
\ No newline at end of file
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
View file @
afd0da21
...
@@ -39,7 +39,7 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
...
@@ -39,7 +39,7 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
bias
)
{
if
(
bias
)
{
...
@@ -58,8 +58,8 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
...
@@ -58,8 +58,8 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
...
@@ -94,7 +94,7 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
...
@@ -94,7 +94,7 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
bias
)
{
if
(
bias
)
{
...
@@ -113,8 +113,8 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
...
@@ -113,8 +113,8 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
...
@@ -165,7 +165,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
...
@@ -165,7 +165,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
bias
)
{
if
(
bias
)
{
...
@@ -184,8 +184,8 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
...
@@ -184,8 +184,8 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
View file @
afd0da21
#include <cudaTypedefs.h>
#include <cudaTypedefs.h>
#include "c3x/scaled_mm_kernels.hpp"
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#include "core/math.hpp"
#include "scaled_mm_c3x_sm90_fp8_dispatch.cuh"
#include "scaled_mm_c3x_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
using
namespace
vllm
;
/*
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper) or later.
NVIDIA GPUs with sm90a (Hopper) or later.
*/
*/
template
<
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_scaled_mm_sm90_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_args
)
{
if
(
a
.
dtype
()
==
torch
::
kInt8
)
{
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm90_int8_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_int8_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
else
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
}
void
cutlass_scaled_mm_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
c
.
dtype
(),
using
GroupShape
=
std
::
array
<
int64_t
,
2
>
;
"currently bias dtype must match output dtype "
,
c
.
dtype
());
return
cutlass_scaled_mm_sm90_epilogue
<
c3x
::
ScaledEpilogueBias
>
(
int
M
=
a
.
size
(
0
),
N
=
b
.
size
(
1
),
K
=
a
.
size
(
1
);
c
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
GroupShape
a_scale_group_shape
=
[
&
,
&
s
=
a_scales
]()
->
GroupShape
{
if
(
s
.
numel
()
==
1
)
return
{
M
,
K
};
// tensor-wise
if
(
s
.
dim
()
==
2
)
return
{
ceil_div
(
a
.
size
(
0
),
s
.
size
(
0
)),
ceil_div
(
a
.
size
(
1
),
s
.
size
(
1
))};
TORCH_CHECK
(
false
,
"Unsupported scale shape for scale_a"
);
}();
GroupShape
b_scale_group_shape
=
[
&
,
&
s
=
b_scales
]()
->
GroupShape
{
if
(
s
.
numel
()
==
1
)
return
{
K
,
N
};
// tensor-wise
if
(
s
.
dim
()
==
2
)
return
{
ceil_div
(
b
.
size
(
0
),
s
.
size
(
0
)),
ceil_div
(
b
.
size
(
1
),
s
.
size
(
1
))};
TORCH_CHECK
(
false
,
"Unsupported scale shape for scale_b"
);
}();
if
((
a_scale_group_shape
==
GroupShape
{
M
,
K
}
||
a_scale_group_shape
==
GroupShape
{
1
,
K
})
&&
(
b_scale_group_shape
==
GroupShape
{
K
,
N
}
||
b_scale_group_shape
==
GroupShape
{
K
,
1
}))
{
// "standard per-tensor/per-token/per-channel" scaling
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
)
{
vllm
::
cutlass_scaled_mm_sm90_fp8
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
vllm
::
cutlass_scaled_mm_sm90_int8
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
}
else
if
(
a_scale_group_shape
==
GroupShape
{
1
,
128
}
&&
b_scale_group_shape
==
GroupShape
{
128
,
128
})
{
// 1x128 per-token group scales for activations
// 128x128 blockwise scales for weights
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
&&
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
,
"Currently only FP8 is supported for A group shape 1x128 and "
"B group shape 128x128"
);
TORCH_CHECK
(
!
bias
,
"Bias not yet supported blockwise scaled_mm"
);
vllm
::
cutlass_scaled_mm_blockwise_sm90_fp8
(
c
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
}
else
{
return
cutlass_scaled_mm_sm90_epilogue
<
c3x
::
ScaledEpilogue
>
(
TORCH_CHECK
(
false
,
c
,
a
,
b
,
a_scales
,
b_scales
);
"Unsupported scale group shapes for CUTLASS 3.x GEMM.
\n
"
"a_scale_group_shape must be [1, 128], got: ["
,
a_scale_group_shape
[
0
],
", "
,
a_scale_group_shape
[
1
],
"]
\n
"
"b_scale_group_shape must be [128, 128], got: ["
,
b_scale_group_shape
[
0
],
", "
,
b_scale_group_shape
[
1
],
"]"
);
}
}
}
}
...
@@ -70,18 +73,11 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
...
@@ -70,18 +73,11 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
azp
)
{
vllm
::
cutlass_scaled_mm_azp_sm90_int8
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
return
cutlass_scaled_mm_sm90_epilogue
<
c3x
::
ScaledEpilogueBiasAzpToken
>
(
azp
,
bias
);
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
*
azp
,
bias
);
}
else
{
return
cutlass_scaled_mm_sm90_epilogue
<
c3x
::
ScaledEpilogueBiasAzp
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
bias
);
}
}
}
#endif
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
afd0da21
...
@@ -9,26 +9,26 @@ void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
...
@@ -9,26 +9,26 @@ void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_sm80
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_sm80
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_sm89
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_sm89
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
void
cutlass_scaled_mm_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#endif
#endif
void
cutlass_scaled_mm_azp_sm75
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_azp_sm75
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
...
@@ -36,24 +36,24 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
...
@@ -36,24 +36,24 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_azp_sm80
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_azp_sm80
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_azp_sm89
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_azp_sm89
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void
cutlass_scaled_mm_azp_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_azp_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
...
@@ -61,8 +61,8 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
...
@@ -61,8 +61,8 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#endif
#endif
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
)
{
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
)
{
...
@@ -81,23 +81,33 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
...
@@ -81,23 +81,33 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
return
false
;
return
false
;
}
}
bool
cutlass_scaled_mm_supports_block_fp8
(
int64_t
cuda_device_capability
)
{
// CUTLASS block-quantized FP8 kernels need at least CUDA 12.0
// and at least SM90 (Hopper)
#if defined CUDA_VERSION
if
(
cuda_device_capability
>=
90
)
{
return
CUDA_VERSION
>=
12000
;
}
#endif
return
false
;
}
void
cutlass_scaled_mm
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
// Checks for conformality
// Checks for conformality
TORCH_CHECK
(
a
.
dim
()
==
2
&&
b
.
dim
()
==
2
&&
c
.
dim
()
==
2
);
TORCH_CHECK
(
a
.
dim
()
==
2
&&
b
.
dim
()
==
2
&&
c
.
dim
()
==
2
);
TORCH_CHECK
(
c
.
size
(
0
)
==
a
.
size
(
0
)
&&
a
.
size
(
1
)
==
b
.
size
(
0
)
&&
TORCH_CHECK
(
c
.
size
(
0
)
==
a
.
size
(
0
)
&&
a
.
size
(
1
)
==
b
.
size
(
0
)
&&
b
.
size
(
1
)
==
c
.
size
(
1
));
b
.
size
(
1
)
==
c
.
size
(
1
));
TORCH_CHECK
(
a_scales
.
numel
()
==
1
||
a_scales
.
numel
()
==
a
.
size
(
0
));
TORCH_CHECK
(
b_scales
.
numel
()
==
1
||
b_scales
.
numel
()
==
b
.
size
(
1
));
// Check for strides and alignment
// Check for strides and alignment
TORCH_CHECK
(
a
.
stride
(
1
)
==
1
&&
c
.
stride
(
1
)
==
1
);
// Row-major
TORCH_CHECK
(
a
.
stride
(
1
)
==
1
&&
c
.
stride
(
1
)
==
1
);
// Row-major
TORCH_CHECK
(
b
.
stride
(
0
)
==
1
);
// Column-major
TORCH_CHECK
(
b
.
stride
(
0
)
==
1
);
// Column-major
TORCH_CHECK
(
c
.
stride
(
0
)
%
16
==
0
&&
TORCH_CHECK
(
c
.
stride
(
0
)
%
16
==
0
&&
b
.
stride
(
1
)
%
16
==
0
);
// 16 Byte Alignment
b
.
stride
(
1
)
%
16
==
0
);
// 16 Byte Alignment
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
bias
)
{
if
(
bias
)
{
TORCH_CHECK
(
bias
->
numel
()
==
b
.
size
(
1
)
&&
bias
->
is_contiguous
()
&&
TORCH_CHECK
(
bias
->
numel
()
==
b
.
size
(
1
)
&&
bias
->
is_contiguous
()
&&
...
@@ -148,8 +158,8 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
...
@@ -148,8 +158,8 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
// Checks for conformality
// Checks for conformality
TORCH_CHECK
(
a
.
dim
()
==
2
&&
b
.
dim
()
==
2
&&
c
.
dim
()
==
2
);
TORCH_CHECK
(
a
.
dim
()
==
2
&&
b
.
dim
()
==
2
&&
c
.
dim
()
==
2
);
TORCH_CHECK
(
c
.
size
(
0
)
==
a
.
size
(
0
)
&&
a
.
size
(
1
)
==
b
.
size
(
0
)
&&
TORCH_CHECK
(
c
.
size
(
0
)
==
a
.
size
(
0
)
&&
a
.
size
(
1
)
==
b
.
size
(
0
)
&&
...
@@ -215,4 +225,4 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
...
@@ -215,4 +225,4 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
"CUDA device capability: "
,
"CUDA device capability: "
,
version_num
);
version_num
);
}
}
\ No newline at end of file
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
afd0da21
...
@@ -173,8 +173,8 @@ dequant<half, vllm::kU4B8.id()>(int q) {
...
@@ -173,8 +173,8 @@ dequant<half, vllm::kU4B8.id()>(int q) {
const
int
HI
=
0x00f000f0
;
const
int
HI
=
0x00f000f0
;
const
int
EX
=
0x64006400
;
const
int
EX
=
0x64006400
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
// directly into `SUB` and `ADD`.
const
int
SUB
=
0x64086408
;
const
int
SUB
=
0x64086408
;
...
@@ -197,9 +197,9 @@ dequant<nv_bfloat16, vllm::kU4B8.id()>(int q) {
...
@@ -197,9 +197,9 @@ dequant<nv_bfloat16, vllm::kU4B8.id()>(int q) {
// Guarantee that the `(a & b) | c` operations are LOP3s.
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
q
>>=
4
;
q
>>=
4
;
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
typename
ScalarType
<
nv_bfloat16
>::
FragB
frag_b
;
typename
ScalarType
<
nv_bfloat16
>::
FragB
frag_b
;
static
constexpr
uint32_t
MUL
=
0x3F803F80
;
static
constexpr
uint32_t
MUL
=
0x3F803F80
;
...
@@ -221,8 +221,8 @@ dequant<half, vllm::kU4.id()>(int q) {
...
@@ -221,8 +221,8 @@ dequant<half, vllm::kU4.id()>(int q) {
const
int
HI
=
0x00f000f0
;
const
int
HI
=
0x00f000f0
;
const
int
EX
=
0x64006400
;
const
int
EX
=
0x64006400
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
const
int
SUB
=
0x64006400
;
const
int
SUB
=
0x64006400
;
const
int
MUL
=
0x2c002c00
;
const
int
MUL
=
0x2c002c00
;
...
@@ -244,9 +244,9 @@ dequant<nv_bfloat16, vllm::kU4.id()>(int q) {
...
@@ -244,9 +244,9 @@ dequant<nv_bfloat16, vllm::kU4.id()>(int q) {
// Guarantee that the `(a & b) | c` operations are LOP3s.
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
q
>>=
4
;
q
>>=
4
;
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
typename
ScalarType
<
nv_bfloat16
>::
FragB
frag_b
;
typename
ScalarType
<
nv_bfloat16
>::
FragB
frag_b
;
static
constexpr
uint32_t
MUL
=
0x3F803F80
;
static
constexpr
uint32_t
MUL
=
0x3F803F80
;
...
@@ -834,6 +834,7 @@ __global__ void Marlin(
...
@@ -834,6 +834,7 @@ __global__ void Marlin(
int4
*
sh_g_idx
=
sh_b
+
(
stages
*
b_sh_stage
);
int4
*
sh_g_idx
=
sh_b
+
(
stages
*
b_sh_stage
);
int4
*
sh_zp
=
sh_g_idx
+
(
stages
*
g_idx_stage
);
int4
*
sh_zp
=
sh_g_idx
+
(
stages
*
g_idx_stage
);
int4
*
sh_s
=
sh_zp
+
(
stages
*
zp_sh_stage
);
int4
*
sh_s
=
sh_zp
+
(
stages
*
zp_sh_stage
);
int4
*
sh_red
=
sh_s
+
(
stages
*
s_sh_stage
);
// Register storage for double buffer of shared memory reads.
// Register storage for double buffer of shared memory reads.
FragA
frag_a
[
2
][
thread_m_blocks
];
FragA
frag_a
[
2
][
thread_m_blocks
];
...
@@ -932,11 +933,11 @@ __global__ void Marlin(
...
@@ -932,11 +933,11 @@ __global__ void Marlin(
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s_stage
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
// Only fetch scales if this tile starts a new group
// Only fetch scales if this tile starts a new group
if
(
pipe
%
(
group_blocks
/
thread_k_blocks
)
==
0
)
{
if
((
pipe
+
1
)
%
(
group_blocks
/
thread_k_blocks
)
==
0
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s_stage
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
s_gl_rd
+=
s_gl_rd_delta
;
s_gl_rd
+=
s_gl_rd_delta
;
}
}
}
else
{
}
else
{
...
@@ -1038,9 +1039,7 @@ __global__ void Marlin(
...
@@ -1038,9 +1039,7 @@ __global__ void Marlin(
// No act-order case
// No act-order case
if
constexpr
(
group_blocks
!=
-
1
)
{
if
constexpr
(
group_blocks
!=
-
1
)
{
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_s_stage
=
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
sh_s
+
s_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
];
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
];
}
else
{
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
int
warp_id
=
threadIdx
.
x
/
32
;
...
@@ -1339,15 +1338,15 @@ __global__ void Marlin(
...
@@ -1339,15 +1338,15 @@ __global__ void Marlin(
int
red_sh_wr
=
int
red_sh_wr
=
red_sh_delta
*
j
+
(
red_sh_rd
-
red_sh_stride
*
i
);
red_sh_delta
*
j
+
(
red_sh_rd
-
red_sh_stride
*
i
);
if
(
i
<
red_off
)
{
if
(
i
<
red_off
)
{
float
*
c_rd
=
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_delta
*
j
+
red_sh_rd
]);
&
sh_red
[
red_sh_delta
*
j
+
red_sh_rd
]);
float
*
c_wr
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_wr
]);
float
*
c_wr
=
reinterpret_cast
<
float
*>
(
&
sh
_red
[
red_sh_wr
]);
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
)
for
(
int
k
=
0
;
k
<
4
;
k
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
j
][
k
]
+=
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
j
][
k
]
+=
c_rd
[
k
]
+
c_wr
[
k
];
c_rd
[
k
]
+
c_wr
[
k
];
}
}
sh
[
red_sh_wr
]
=
sh
_red
[
red_sh_wr
]
=
reinterpret_cast
<
int4
*>
(
&
frag_c
)[
4
*
2
*
m_block
+
j
];
reinterpret_cast
<
int4
*>
(
&
frag_c
)[
4
*
2
*
m_block
+
j
];
}
}
}
}
...
@@ -1357,7 +1356,7 @@ __global__ void Marlin(
...
@@ -1357,7 +1356,7 @@ __global__ void Marlin(
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
4
*
2
;
i
++
)
{
for
(
int
i
=
0
;
i
<
4
*
2
;
i
++
)
{
float
*
c_rd
=
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_delta
*
i
+
red_sh_rd
]);
reinterpret_cast
<
float
*>
(
&
sh
_red
[
red_sh_delta
*
i
+
red_sh_rd
]);
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
for
(
int
j
=
0
;
j
<
4
;
j
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
i
][
j
]
+=
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
i
][
j
]
+=
...
@@ -1397,7 +1396,7 @@ __global__ void Marlin(
...
@@ -1397,7 +1396,7 @@ __global__ void Marlin(
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
cp_async4_pred
(
cp_async4_pred
(
&
sh
[
c_sh_wr
+
c_sh_wr_delta
*
i
],
&
sh
_red
[
c_sh_wr
+
c_sh_wr_delta
*
i
],
&
C
[
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
&
C
[
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
c_gl_wr_delta_i
*
(
i
%
2
)],
c_gl_wr_delta_i
*
(
i
%
2
)],
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
row
<
prob_m
);
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
row
<
prob_m
);
...
@@ -1410,7 +1409,7 @@ __global__ void Marlin(
...
@@ -1410,7 +1409,7 @@ __global__ void Marlin(
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
if
(
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
row
<
prob_m
)
{
if
(
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
row
<
prob_m
)
{
if
(
!
first
)
{
if
(
!
first
)
{
int4
c_red
=
sh
[
c_sh_wr
+
i
*
c_sh_wr_delta
];
int4
c_red
=
sh
_red
[
c_sh_wr
+
i
*
c_sh_wr_delta
];
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
float
*>
(
...
@@ -1461,10 +1460,10 @@ __global__ void Marlin(
...
@@ -1461,10 +1460,10 @@ __global__ void Marlin(
float
*
frag_c_ptr
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
float
*
frag_c_ptr
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
th_size
;
k
++
)
{
for
(
int
k
=
0
;
k
<
th_size
;
k
++
)
{
sh
[
threadIdx
.
x
]
=
sh
_red
[
threadIdx
.
x
]
=
C_tmp
[
c_cur_offset
+
active_threads
*
k
+
threadIdx
.
x
];
C_tmp
[
c_cur_offset
+
active_threads
*
k
+
threadIdx
.
x
];
float
*
sh_c_ptr
=
reinterpret_cast
<
float
*>
(
&
sh
[
threadIdx
.
x
]);
float
*
sh_c_ptr
=
reinterpret_cast
<
float
*>
(
&
sh
_red
[
threadIdx
.
x
]);
#pragma unroll
#pragma unroll
for
(
int
f
=
0
;
f
<
4
;
f
++
)
{
for
(
int
f
=
0
;
f
<
4
;
f
++
)
{
frag_c_ptr
[
k
*
4
+
f
]
+=
sh_c_ptr
[
f
];
frag_c_ptr
[
k
*
4
+
f
]
+=
sh_c_ptr
[
f
];
...
@@ -1515,7 +1514,7 @@ __global__ void Marlin(
...
@@ -1515,7 +1514,7 @@ __global__ void Marlin(
res
=
__hmul2
(
res
,
s
[
0
]);
res
=
__hmul2
(
res
,
s
[
0
]);
}
}
((
scalar_t2
*
)
sh
)[
idx
]
=
res
;
((
scalar_t2
*
)
sh
_red
)[
idx
]
=
res
;
};
};
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
...
@@ -1543,7 +1542,7 @@ __global__ void Marlin(
...
@@ -1543,7 +1542,7 @@ __global__ void Marlin(
i
<
div_ceil
(
16
*
thread_m_blocks
,
threads
/
(
2
*
thread_n_blocks
));
i
<
div_ceil
(
16
*
thread_m_blocks
,
threads
/
(
2
*
thread_n_blocks
));
i
++
)
{
i
++
)
{
if
(
c_gl_wr
<
c_gl_wr_end
)
{
if
(
c_gl_wr
<
c_gl_wr_end
)
{
C
[
c_gl_wr
]
=
sh
[
c_sh_rd
];
C
[
c_gl_wr
]
=
sh
_red
[
c_sh_rd
];
c_gl_wr
+=
c_gl_wr_delta
;
c_gl_wr
+=
c_gl_wr_delta
;
c_sh_rd
+=
c_sh_rd_delta
;
c_sh_rd
+=
c_sh_rd_delta
;
}
}
...
@@ -1865,9 +1864,12 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
...
@@ -1865,9 +1864,12 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
float
pipe_size
=
(
a_size
+
b_size
)
*
pipe_stages
;
float
pipe_size
=
(
a_size
+
b_size
)
*
pipe_stages
;
float
reduce_size
=
max
(
th_config
.
num_threads
*
32
*
4
,
(
tb_n
/
64
)
*
32
*
(
tb_max_m
/
16
)
*
4
*
2
*
4
*
2
);
TORCH_CHECK
(
max_shared_mem
/
2
>
scales_cache_size
);
// Sanity
TORCH_CHECK
(
max_shared_mem
/
2
>
scales_cache_size
);
// Sanity
return
pipe_size
<
0.95
f
*
(
max_shared_mem
-
scales_cache_size
);
return
pipe_size
+
reduce_size
<
0.95
f
*
(
max_shared_mem
-
scales_cache_size
);
}
}
bool
is_valid_config
(
thread_config_t
const
&
th_config
,
int
max_m_blocks
,
bool
is_valid_config
(
thread_config_t
const
&
th_config
,
int
max_m_blocks
,
...
...
csrc/quantization/machete/generate.py
View file @
afd0da21
...
@@ -63,7 +63,7 @@ torch::Tensor mm_dispatch_{{type_sig}}(MMArgs args) {
...
@@ -63,7 +63,7 @@ torch::Tensor mm_dispatch_{{type_sig}}(MMArgs args) {
static inline std::optional<at::ScalarType> maybe_scalartype(
static inline std::optional<at::ScalarType> maybe_scalartype(
c10
::optional<at::Tensor> const& t) {
std
::optional<at::Tensor> const& t) {
if (!t) {
if (!t) {
return std::nullopt;
return std::nullopt;
} else {
} else {
...
@@ -189,7 +189,7 @@ using Kernel_{{type_sig}} = MacheteKernelTemplate<
...
@@ -189,7 +189,7 @@ using Kernel_{{type_sig}} = MacheteKernelTemplate<
{{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT
{{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT
{{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT
{{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT
{{DataTypeTag[t.a_token_scale]}}, // TokenScaleT
{{DataTypeTag[t.a_token_scale]}}, // TokenScaleT
cutlass::gemm::KernelTmaWarpSpecializedCooperative
MixedInput
,
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
Sch>;
Sch>;
{% for sch in schs %}
{% for sch in schs %}
...
@@ -223,7 +223,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
...
@@ -223,7 +223,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
{{DataTypeTag[t.convert]}}, // ElementConvert
{{DataTypeTag[t.convert]}}, // ElementConvert
{{DataTypeTag[t.accumulator]}}, // Accumulator
{{DataTypeTag[t.accumulator]}}, // Accumulator
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
cutlass::gemm::KernelTmaWarpSpecializedCooperative
MixedInput
>
cutlass::gemm::KernelTmaWarpSpecializedCooperative>
>(args.B);
>(args.B);
}
}
{%- endfor %}
{%- endfor %}
...
@@ -239,7 +239,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
...
@@ -239,7 +239,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
}; // namespace machete
}; // namespace machete
"""
"""
TmaMI
=
MixedInputKernelScheduleType
.
TmaWarpSpecializedCooperative
MixedInput
TmaMI
=
MixedInputKernelScheduleType
.
TmaWarpSpecializedCooperative
TmaCoop
=
EpilogueScheduleType
.
TmaWarpSpecializedCooperative
TmaCoop
=
EpilogueScheduleType
.
TmaWarpSpecializedCooperative
...
@@ -300,7 +300,7 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
...
@@ -300,7 +300,7 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
# mostly unique shorter sch_sig
# mostly unique shorter sch_sig
def
generate_terse_sch_sig
(
schedule_config
:
ScheduleConfig
)
->
str
:
def
generate_terse_sch_sig
(
schedule_config
:
ScheduleConfig
)
->
str
:
kernel_terse_names_replace
=
{
kernel_terse_names_replace
=
{
"KernelTmaWarpSpecializedCooperative
MixedInput_
"
:
"TmaMI_"
,
"KernelTmaWarpSpecializedCooperative"
:
"TmaMI_"
,
"TmaWarpSpecializedCooperative_"
:
"TmaCoop_"
,
"TmaWarpSpecializedCooperative_"
:
"TmaCoop_"
,
"StreamKScheduler"
:
"streamK"
,
"StreamKScheduler"
:
"streamK"
,
}
}
...
...
csrc/quantization/machete/machete_collective_builder.cuh
View file @
afd0da21
...
@@ -18,16 +18,14 @@ struct VLLMCollectiveBuilder<
...
@@ -18,16 +18,14 @@ struct VLLMCollectiveBuilder<
ElementAccumulator
,
TileShape_MNK
,
ClusterShape_MNK
,
StageCountType
,
ElementAccumulator
,
TileShape_MNK
,
ClusterShape_MNK
,
StageCountType
,
KernelScheduleType
,
KernelScheduleType
,
cute
::
enable_if_t
<
(
cute
::
enable_if_t
<
(
cute
::
is_same_v
<
KernelScheduleType
,
KernelTmaWarpSpecialized
>
||
cute
::
is_same_v
<
KernelScheduleType
,
KernelTmaWarpSpecializedPingpong
>
||
cute
::
is_same_v
<
KernelScheduleType
,
cute
::
is_same_v
<
KernelScheduleType
,
KernelTmaWarpSpecializedMixedInput
>
||
KernelTmaWarpSpecializedCooperative
>
)
>>
{
cute
::
is_same_v
<
KernelScheduleType
,
KernelTmaWarpSpecializedPingpongMixedInput
>
||
cute
::
is_same_v
<
KernelScheduleType
,
KernelTmaWarpSpecializedCooperativeMixedInput
>
)
>>
{
using
CollectiveOp
=
machete
::
MacheteCollectiveMma
<
using
CollectiveOp
=
machete
::
MacheteCollectiveMma
<
ElementPairA_
,
GmemLayoutA_
,
AlignmentA
,
ElementPairB_
,
GmemLayoutB_
,
ElementPairA_
,
GmemLayoutA_
,
AlignmentA
,
ElementPairB_
,
GmemLayoutB_
,
AlignmentB
,
ElementAccumulator
,
TileShape_MNK
,
ClusterShape_MNK
,
AlignmentB
,
ElementAccumulator
,
TileShape_MNK
,
ClusterShape_MNK
,
StageCountType
,
KernelScheduleType
>
;
StageCountType
,
KernelScheduleType
>
;
};
};
};
// namespace cutlass::gemm::collective
};
// namespace cutlass::gemm::collective
\ No newline at end of file
csrc/quantization/machete/machete_mainloop.cuh
View file @
afd0da21
...
@@ -66,13 +66,11 @@ struct MacheteCollectiveMma {
...
@@ -66,13 +66,11 @@ struct MacheteCollectiveMma {
using
Schedule
=
KernelScheduleType
;
using
Schedule
=
KernelScheduleType
;
static_assert
(
static_assert
(
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecialized
>
||
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecialized
>
||
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecializedMixedInput
>
||
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecialized
>
||
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecializedPingpong
>
||
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecializedPingpong
>
||
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecializedPingpong
>
||
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecializedPingpongMixedInput
>
||
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecializedCooperative
>
||
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecializedCooperative
>
||
cute
::
is_same_v
<
Schedule
,
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecializedCooperative
>
,
KernelTmaWarpSpecializedCooperativeMixedInput
>
,
"KernelSchedule must be one of the warp specialized policies"
);
"KernelSchedule must be one of the warp specialized policies"
);
public:
public:
...
@@ -113,8 +111,7 @@ struct MacheteCollectiveMma {
...
@@ -113,8 +111,7 @@ struct MacheteCollectiveMma {
// For coop schedules we have two warp groups cooperatively issuing wgmma
// For coop schedules we have two warp groups cooperatively issuing wgmma
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
using
AtomLayoutMNK
=
cute
::
conditional_t
<
using
AtomLayoutMNK
=
cute
::
conditional_t
<
cute
::
is_same_v
<
KernelScheduleType
,
cute
::
is_same_v
<
KernelScheduleType
,
KernelTmaWarpSpecializedCooperative
>
,
KernelTmaWarpSpecializedCooperativeMixedInput
>
,
Layout
<
Shape
<
_2
,
_1
,
_1
>>
,
Layout
<
Shape
<
_1
,
_1
,
_1
>>>
;
Layout
<
Shape
<
_2
,
_1
,
_1
>>
,
Layout
<
Shape
<
_1
,
_1
,
_1
>>>
;
using
TiledMma
=
decltype
(
cute
::
make_tiled_mma
(
using
TiledMma
=
decltype
(
cute
::
make_tiled_mma
(
...
@@ -275,6 +272,10 @@ struct MacheteCollectiveMma {
...
@@ -275,6 +272,10 @@ struct MacheteCollectiveMma {
using
PipelineState
=
cutlass
::
PipelineState
<
DispatchPolicy
::
Stages
>
;
using
PipelineState
=
cutlass
::
PipelineState
<
DispatchPolicy
::
Stages
>
;
using
PipelineParams
=
typename
MainloopPipeline
::
Params
;
using
PipelineParams
=
typename
MainloopPipeline
::
Params
;
// One threads per CTA are producers (1 for operand tile)
static
constexpr
int
NumProducerThreadEvents
=
1
;
using
ScaleTileShape
=
decltype
(
make_shape
(
shape
<
0
>
(
TileShape
{}),
using
ScaleTileShape
=
decltype
(
make_shape
(
shape
<
0
>
(
TileShape
{}),
shape
<
1
>
(
SmemLayoutAtomScale
{})));
shape
<
1
>
(
SmemLayoutAtomScale
{})));
...
...
csrc/quantization/machete/machete_mm_kernel.cuh
View file @
afd0da21
...
@@ -183,11 +183,11 @@ struct MacheteKernelTemplate {
...
@@ -183,11 +183,11 @@ struct MacheteKernelTemplate {
torch
::
Tensor
const
&
A
,
// MxK matrix
torch
::
Tensor
const
&
A
,
// MxK matrix
torch
::
Tensor
const
&
B
,
// KxN prepacked matrix
torch
::
Tensor
const
&
B
,
// KxN prepacked matrix
torch
::
Tensor
&
D
,
// MxN matrix
torch
::
Tensor
&
D
,
// MxN matrix
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_g_scales
,
// scale_KxN matrix
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_g_scales
,
// scale_KxN matrix
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_g_zeros
,
// scale_KxN matrix
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_g_zeros
,
// scale_KxN matrix
c10
::
optional
<
int64_t
>
maybe_group_size
,
std
::
optional
<
int64_t
>
maybe_group_size
,
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_ch_scales
,
// len N vector
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_ch_scales
,
// len N vector
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_tok_scales
)
// len M vector
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_tok_scales
)
// len M vector
{
{
static_assert
(
!
with_group_zeropoints
||
with_group_scales
);
static_assert
(
!
with_group_zeropoints
||
with_group_scales
);
...
...
csrc/quantization/machete/machete_mm_launcher.cuh
View file @
afd0da21
...
@@ -13,23 +13,23 @@ struct MMArgs {
...
@@ -13,23 +13,23 @@ struct MMArgs {
torch
::
Tensor
const
&
A
;
torch
::
Tensor
const
&
A
;
torch
::
Tensor
const
&
B
;
torch
::
Tensor
const
&
B
;
vllm
::
ScalarType
const
&
b_type
;
vllm
::
ScalarType
const
&
b_type
;
c10
::
optional
<
at
::
ScalarType
>
const
&
maybe_out_type
;
std
::
optional
<
at
::
ScalarType
>
const
&
maybe_out_type
;
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_group_scales
;
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_group_scales
;
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_group_zeros
;
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_group_zeros
;
c10
::
optional
<
int64_t
>
maybe_group_size
;
std
::
optional
<
int64_t
>
maybe_group_size
;
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_channel_scales
;
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_channel_scales
;
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_token_scales
;
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_token_scales
;
c10
::
optional
<
std
::
string
>
maybe_schedule
;
std
::
optional
<
std
::
string
>
maybe_schedule
;
};
};
struct
SupportedSchedulesArgs
{
struct
SupportedSchedulesArgs
{
at
::
ScalarType
a_type
;
at
::
ScalarType
a_type
;
vllm
::
ScalarType
b_type
;
vllm
::
ScalarType
b_type
;
c10
::
optional
<
at
::
ScalarType
>
maybe_group_scales_type
;
std
::
optional
<
at
::
ScalarType
>
maybe_group_scales_type
;
c10
::
optional
<
at
::
ScalarType
>
maybe_group_zeros_type
;
std
::
optional
<
at
::
ScalarType
>
maybe_group_zeros_type
;
c10
::
optional
<
at
::
ScalarType
>
maybe_channel_scales_type
;
std
::
optional
<
at
::
ScalarType
>
maybe_channel_scales_type
;
c10
::
optional
<
at
::
ScalarType
>
maybe_token_scales_type
;
std
::
optional
<
at
::
ScalarType
>
maybe_token_scales_type
;
c10
::
optional
<
at
::
ScalarType
>
maybe_out_type
;
std
::
optional
<
at
::
ScalarType
>
maybe_out_type
;
};
};
torch
::
Tensor
mm_dispatch
(
MMArgs
args
);
torch
::
Tensor
mm_dispatch
(
MMArgs
args
);
...
...
csrc/quantization/machete/machete_prepack_launcher.cuh
View file @
afd0da21
...
@@ -10,7 +10,7 @@ struct PrepackBArgs {
...
@@ -10,7 +10,7 @@ struct PrepackBArgs {
torch
::
Tensor
const
&
B
;
torch
::
Tensor
const
&
B
;
at
::
ScalarType
a_type
;
at
::
ScalarType
a_type
;
vllm
::
ScalarType
b_type
;
vllm
::
ScalarType
b_type
;
c10
::
optional
<
at
::
ScalarType
>
maybe_group_scales_type
;
std
::
optional
<
at
::
ScalarType
>
maybe_group_scales_type
;
};
};
template
<
typename
PrepackedLayoutB
>
template
<
typename
PrepackedLayoutB
>
...
...
csrc/quantization/machete/machete_prepacked_layout.cuh
View file @
afd0da21
...
@@ -98,8 +98,7 @@ struct PrepackedLayoutBTemplate {
...
@@ -98,8 +98,7 @@ struct PrepackedLayoutBTemplate {
// For coop schedules we have two warp groups cooperatively issuing wgmma
// For coop schedules we have two warp groups cooperatively issuing wgmma
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
using
AtomLayoutMNK
=
cute
::
conditional_t
<
using
AtomLayoutMNK
=
cute
::
conditional_t
<
cute
::
is_same_v
<
KernelSchedule
,
cute
::
is_same_v
<
KernelSchedule
,
KernelTmaWarpSpecializedCooperative
>
,
KernelTmaWarpSpecializedCooperativeMixedInput
>
,
Layout
<
Shape
<
_2
,
_1
,
_1
>>
,
Layout
<
Shape
<
_1
,
_1
,
_1
>>>
;
Layout
<
Shape
<
_2
,
_1
,
_1
>>
,
Layout
<
Shape
<
_1
,
_1
,
_1
>>>
;
using
TiledMma
=
decltype
(
cute
::
make_tiled_mma
(
using
TiledMma
=
decltype
(
cute
::
make_tiled_mma
(
...
@@ -247,4 +246,4 @@ struct PrepackedLayoutBTemplate {
...
@@ -247,4 +246,4 @@ struct PrepackedLayoutBTemplate {
}
}
};
};
};
// namespace machete
};
// namespace machete
\ No newline at end of file
csrc/quantization/machete/machete_pytorch.cu
View file @
afd0da21
...
@@ -10,11 +10,11 @@ using namespace vllm;
...
@@ -10,11 +10,11 @@ using namespace vllm;
std
::
vector
<
std
::
string
>
supported_schedules
(
std
::
vector
<
std
::
string
>
supported_schedules
(
at
::
ScalarType
a_type
,
int64_t
b_type_id
,
at
::
ScalarType
a_type
,
int64_t
b_type_id
,
c10
::
optional
<
at
::
ScalarType
>
maybe_group_scales_type
,
std
::
optional
<
at
::
ScalarType
>
maybe_group_scales_type
,
c10
::
optional
<
at
::
ScalarType
>
maybe_group_zeros_type
,
std
::
optional
<
at
::
ScalarType
>
maybe_group_zeros_type
,
c10
::
optional
<
at
::
ScalarType
>
maybe_channel_scales_type
,
std
::
optional
<
at
::
ScalarType
>
maybe_channel_scales_type
,
c10
::
optional
<
at
::
ScalarType
>
maybe_token_scales_type
,
std
::
optional
<
at
::
ScalarType
>
maybe_token_scales_type
,
c10
::
optional
<
at
::
ScalarType
>
maybe_out_type
)
{
std
::
optional
<
at
::
ScalarType
>
maybe_out_type
)
{
ScalarType
const
b_type
=
ScalarType
::
from_id
(
b_type_id
);
ScalarType
const
b_type
=
ScalarType
::
from_id
(
b_type_id
);
return
supported_schedules_dispatch
({
return
supported_schedules_dispatch
({
.
a_type
=
a_type
,
.
a_type
=
a_type
,
...
@@ -29,13 +29,13 @@ std::vector<std::string> supported_schedules(
...
@@ -29,13 +29,13 @@ std::vector<std::string> supported_schedules(
torch
::
Tensor
mm
(
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
B
,
torch
::
Tensor
mm
(
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
B
,
int64_t
b_type_id
,
int64_t
b_type_id
,
c10
::
optional
<
at
::
ScalarType
>
const
&
maybe_out_type
,
std
::
optional
<
at
::
ScalarType
>
const
&
maybe_out_type
,
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_group_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_group_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_group_zeros
,
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_group_zeros
,
c10
::
optional
<
int64_t
>
maybe_group_size
,
std
::
optional
<
int64_t
>
maybe_group_size
,
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_channel_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_channel_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_token_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_token_scales
,
c10
::
optional
<
std
::
string
>
maybe_schedule
)
{
std
::
optional
<
std
::
string
>
maybe_schedule
)
{
ScalarType
const
b_type
=
ScalarType
::
from_id
(
b_type_id
);
ScalarType
const
b_type
=
ScalarType
::
from_id
(
b_type_id
);
return
mm_dispatch
({.
A
=
A
,
return
mm_dispatch
({.
A
=
A
,
.
B
=
B
,
.
B
=
B
,
...
@@ -51,7 +51,7 @@ torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B,
...
@@ -51,7 +51,7 @@ torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B,
torch
::
Tensor
prepack_B
(
torch
::
Tensor
prepack_B
(
torch
::
Tensor
const
&
B
,
at
::
ScalarType
const
&
a_type
,
int64_t
b_type_id
,
torch
::
Tensor
const
&
B
,
at
::
ScalarType
const
&
a_type
,
int64_t
b_type_id
,
c10
::
optional
<
at
::
ScalarType
>
const
&
maybe_group_scales_type
)
{
std
::
optional
<
at
::
ScalarType
>
const
&
maybe_group_scales_type
)
{
ScalarType
const
b_type
=
ScalarType
::
from_id
(
b_type_id
);
ScalarType
const
b_type
=
ScalarType
::
from_id
(
b_type_id
);
return
prepack_B_dispatch
(
return
prepack_B_dispatch
(
{.
B
=
B
,
{.
B
=
B
,
...
...
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
View file @
afd0da21
...
@@ -96,8 +96,8 @@ __device__ inline FragB dequant(int q) {
...
@@ -96,8 +96,8 @@ __device__ inline FragB dequant(int q) {
const
int
HI
=
0x00f000f0
;
const
int
HI
=
0x00f000f0
;
const
int
EX
=
0x64006400
;
const
int
EX
=
0x64006400
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
// directly into `SUB` and `ADD`.
const
int
SUB
=
0x64086408
;
const
int
SUB
=
0x64086408
;
...
...
Prev
1
2
3
4
5
6
7
8
9
10
…
30
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment