Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
afd0da21
Commit
afd0da21
authored
Feb 03, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.7.1' into v0.7.1-dev
parents
1a11f127
4f4d427a
Changes
587
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
497 additions
and
171 deletions
+497
-171
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu
...tization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu
+24
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
...utlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
+168
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
+33
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu
+24
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh
...tization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh
+25
-1
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
+24
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh
...ization/cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh
+24
-1
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
+9
-9
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+56
-60
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+29
-19
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+29
-27
csrc/quantization/machete/generate.py
csrc/quantization/machete/generate.py
+5
-5
csrc/quantization/machete/machete_collective_builder.cuh
csrc/quantization/machete/machete_collective_builder.cuh
+4
-6
csrc/quantization/machete/machete_mainloop.cuh
csrc/quantization/machete/machete_mainloop.cuh
+8
-7
csrc/quantization/machete/machete_mm_kernel.cuh
csrc/quantization/machete/machete_mm_kernel.cuh
+5
-5
csrc/quantization/machete/machete_mm_launcher.cuh
csrc/quantization/machete/machete_mm_launcher.cuh
+12
-12
csrc/quantization/machete/machete_prepack_launcher.cuh
csrc/quantization/machete/machete_prepack_launcher.cuh
+1
-1
csrc/quantization/machete/machete_prepacked_layout.cuh
csrc/quantization/machete/machete_prepacked_layout.cuh
+2
-3
csrc/quantization/machete/machete_pytorch.cu
csrc/quantization/machete/machete_pytorch.cu
+13
-13
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
+2
-2
No files found.
Too many changes to show.
To preserve performance only
587 of 587+
files are displayed.
Plain diff
Email patch
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu
0 → 100644
View file @
afd0da21
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace
vllm
{
void
cutlass_scaled_mm_blockwise_sm90_fp8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
cutlass_gemm_blockwise_sm90_fp8_dispatch
<
cutlass
::
bfloat16_t
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
cutlass_gemm_blockwise_sm90_fp8_dispatch
<
cutlass
::
half_t
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
// namespace vllm
\ No newline at end of file
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
0 → 100644
View file @
afd0da21
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_gemm_caller.cuh"
namespace
vllm
{
using
namespace
cute
;
template
<
typename
OutType
,
int
GroupSizeM_
,
int
GroupSizeN_
,
int
GroupSizeK_
,
int
TileSizeM_
=
128
,
class
ClusterShape
=
Shape
<
_1
,
_2
,
_1
>
>
struct
cutlass_3x_gemm_fp8_blockwise
{
using
GroupSizeM
=
Int
<
GroupSizeM_
>
;
using
GroupSizeN
=
Int
<
GroupSizeN_
>
;
using
GroupSizeK
=
Int
<
GroupSizeK_
>
;
using
TileSizeM
=
Int
<
TileSizeM_
>
;
static_assert
(
TileSizeM_
%
GroupSizeM_
==
0
,
"TileSizeM must be a multiple of GroupSizeM"
);
using
ElementAB
=
cutlass
::
float_e4m3_t
;
using
ElementA
=
ElementAB
;
using
LayoutA
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
AlignmentA
=
128
/
cutlass
::
sizeof_bits
<
ElementA
>::
value
;
using
ElementB
=
ElementAB
;
using
LayoutB
=
cutlass
::
layout
::
ColumnMajor
;
static
constexpr
int
AlignmentB
=
128
/
cutlass
::
sizeof_bits
<
ElementB
>::
value
;
using
ElementD
=
OutType
;
using
StrideD
=
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>
;
static
constexpr
int
AlignmentD
=
128
/
cutlass
::
sizeof_bits
<
ElementD
>::
value
;
using
ElementC
=
void
;
using
StrideC
=
StrideD
;
static
constexpr
int
AlignmentC
=
AlignmentD
;
using
ElementAccumulator
=
float
;
using
ElementBlockScale
=
float
;
using
ElementCompute
=
float
;
using
ArchTag
=
cutlass
::
arch
::
Sm90
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
using
TileShape
=
Shape
<
TileSizeM
,
GroupSizeN
,
GroupSizeK
>
;
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum
<
GroupSizeM_
>
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
TmaWarpSpecializedCooperative
;
using
EpilogueTileType
=
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
;
using
StoreEpilogueCompute
=
typename
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
>
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
TileShape
,
ClusterShape
,
EpilogueTileType
,
ElementAccumulator
,
ElementCompute
,
ElementC
,
StrideC
,
AlignmentC
,
ElementD
,
StrideD
,
AlignmentD
,
EpilogueSchedule
,
StoreEpilogueCompute
>::
CollectiveOp
;
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
ElementA
,
LayoutA
,
AlignmentA
,
ElementB
,
LayoutB
,
AlignmentB
,
ElementAccumulator
,
TileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
KernelSchedule
>::
CollectiveOp
;
using
KernelType
=
enable_sm90_or_later
<
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
Shape
<
int
,
int
,
int
,
int
>
,
CollectiveMainloop
,
CollectiveEpilogue
,
cutlass
::
gemm
::
PersistentScheduler
>>
;
struct
GemmKernel
:
public
KernelType
{};
using
StrideA
=
typename
GemmKernel
::
StrideA
;
using
StrideB
=
typename
GemmKernel
::
StrideB
;
};
template
<
typename
Gemm
>
void
cutlass_gemm_caller_blockwise
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
auto
prob_shape
=
c3x
::
get_problem_shape
(
a
,
b
);
int32_t
m
=
get
<
0
>
(
prob_shape
),
n
=
get
<
1
>
(
prob_shape
),
k
=
get
<
2
>
(
prob_shape
);
int64_t
lda
=
a
.
stride
(
0
);
int64_t
ldb
=
b
.
stride
(
1
);
int64_t
ldc
=
out
.
stride
(
0
);
using
StrideA
=
Stride
<
int64_t
,
Int
<
1
>
,
int64_t
>
;
using
StrideB
=
Stride
<
int64_t
,
Int
<
1
>
,
int64_t
>
;
using
StrideC
=
typename
Gemm
::
StrideC
;
StrideA
a_stride
{
lda
,
Int
<
1
>
{},
0
};
StrideB
b_stride
{
ldb
,
Int
<
1
>
{},
0
};
StrideC
c_stride
{
ldc
,
Int
<
1
>
{},
Int
<
0
>
{}};
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
*>
(
b
.
data_ptr
());
auto
a_scales_ptr
=
static_cast
<
float
*>
(
a_scales
.
data_ptr
());
auto
b_scales_ptr
=
static_cast
<
float
*>
(
b_scales
.
data_ptr
());
// Check is the t is contiguous and is 1D or 2D with one of the dimensions
// being 1 (i.e. a row or column vector)
auto
is_contiguous_vector
=
[](
const
torch
::
Tensor
&
t
)
{
auto
t_sizes
=
t
.
sizes
();
return
t
.
is_contiguous
()
&&
(
t
.
dim
()
==
1
||
(
t
.
dim
()
==
2
&&
*
std
::
min_element
(
t_sizes
.
begin
(),
t_sizes
.
end
())
==
1
));
};
// TODO(lucas): lets clean-up the kernel so that we pass in Strides so
// we don't have to deal with enforcing implicit layouts
TORCH_CHECK
(
a_scales
.
size
(
0
)
==
m
/
Gemm
::
GroupSizeM
::
value
);
TORCH_CHECK
(
a_scales
.
size
(
1
)
==
k
/
Gemm
::
GroupSizeK
::
value
);
TORCH_CHECK
(
a_scales
.
stride
(
0
)
==
1
||
is_contiguous_vector
(
a_scales
),
"a_scales must be M major"
);
TORCH_CHECK
(
b_scales
.
size
(
0
)
==
k
/
Gemm
::
GroupSizeK
::
value
);
TORCH_CHECK
(
b_scales
.
size
(
1
)
==
n
/
Gemm
::
GroupSizeN
::
value
);
TORCH_CHECK
(
b_scales
.
stride
(
0
)
==
1
||
is_contiguous_vector
(
b_scales
),
"b_scales must be K major"
);
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
a_ptr
,
a_stride
,
b_ptr
,
b_stride
,
a_scales_ptr
,
b_scales_ptr
};
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
{},
c_ptr
,
c_stride
,
c_ptr
,
c_stride
};
c3x
::
cutlass_gemm_caller
<
GemmKernel
>
(
a
.
device
(),
prob_shape
,
mainloop_args
,
epilogue_args
);
}
template
<
typename
OutType
>
void
cutlass_gemm_blockwise_sm90_fp8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
OutType
,
1
,
128
,
128
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
// namespace vllm
\ No newline at end of file
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
0 → 100644
View file @
afd0da21
#pragma once
#include <torch/all.h>
namespace
vllm
{
void
cutlass_scaled_mm_sm90_fp8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_sm90_int8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_azp_sm90_int8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_blockwise_sm90_fp8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
}
// namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu
0 → 100644
View file @
afd0da21
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace
vllm
{
void
cutlass_scaled_mm_sm90_fp8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"currently bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_mm_sm90_fp8_epilogue
<
c3x
::
ScaledEpilogueBias
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
return
cutlass_scaled_mm_sm90_fp8_epilogue
<
c3x
::
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_
c3x_
sm90_fp8_dispatch.cuh
→
csrc/quantization/cutlass_w8a8/
c3x/
scaled_mm_sm90_fp8_dispatch.cuh
View file @
afd0da21
#pragma once
#include "scaled_mm_c3x.cuh"
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
/**
* This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm
...
...
@@ -9,6 +10,8 @@
namespace
vllm
{
using
c3x
::
cutlass_gemm_caller
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_default
{
...
...
@@ -93,4 +96,25 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
}
}
template
<
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_scaled_mm_sm90_fp8_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_args
)
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
}
// namespace vllm
\ No newline at end of file
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
0 → 100644
View file @
afd0da21
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace
vllm
{
void
cutlass_scaled_mm_sm90_int8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"currently bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_mm_sm90_int8_epilogue
<
c3x
::
ScaledEpilogueBias
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
}
else
{
return
cutlass_scaled_mm_sm90_int8_epilogue
<
c3x
::
ScaledEpilogue
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_
c3x_
sm90_int8_dispatch.cuh
→
csrc/quantization/cutlass_w8a8/
c3x/
scaled_mm_sm90_int8_dispatch.cuh
View file @
afd0da21
#pragma once
#include "scaled_mm_c3x.cuh"
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
/**
* This file defines Gemm kernel configurations for SM90 (int8) based on the
...
...
@@ -9,6 +10,8 @@
namespace
vllm
{
using
c3x
::
cutlass_gemm_caller
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_default
{
...
...
@@ -137,4 +140,24 @@ inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
}
}
template
<
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_scaled_mm_sm90_int8_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_args
)
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm90_int8_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_int8_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
}
// namespace vllm
\ No newline at end of file
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
View file @
afd0da21
...
...
@@ -39,7 +39,7 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
bias
)
{
...
...
@@ -58,8 +58,8 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
...
...
@@ -94,7 +94,7 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
bias
)
{
...
...
@@ -113,8 +113,8 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
...
...
@@ -165,7 +165,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
bias
)
{
...
...
@@ -184,8 +184,8 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
View file @
afd0da21
#include <cudaTypedefs.h>
#include "c3x/scaled_mm_kernels.hpp"
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#include "scaled_mm_c3x_sm90_fp8_dispatch.cuh"
#include "scaled_mm_c3x_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
using
namespace
vllm
;
#include "core/math.hpp"
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper) or later.
*/
template
<
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_scaled_mm_sm90_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_args
)
{
if
(
a
.
dtype
()
==
torch
::
kInt8
)
{
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm90_int8_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_int8_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
else
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
Epilogue
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
}
void
cutlass_scaled_mm_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
c
.
dtype
(),
"currently bias dtype must match output dtype "
,
c
.
dtype
());
return
cutlass_scaled_mm_sm90_epilogue
<
c3x
::
ScaledEpilogueBias
>
(
c
,
a
,
b
,
a_scales
,
b_scales
,
*
bias
);
using
GroupShape
=
std
::
array
<
int64_t
,
2
>
;
int
M
=
a
.
size
(
0
),
N
=
b
.
size
(
1
),
K
=
a
.
size
(
1
);
GroupShape
a_scale_group_shape
=
[
&
,
&
s
=
a_scales
]()
->
GroupShape
{
if
(
s
.
numel
()
==
1
)
return
{
M
,
K
};
// tensor-wise
if
(
s
.
dim
()
==
2
)
return
{
ceil_div
(
a
.
size
(
0
),
s
.
size
(
0
)),
ceil_div
(
a
.
size
(
1
),
s
.
size
(
1
))};
TORCH_CHECK
(
false
,
"Unsupported scale shape for scale_a"
);
}();
GroupShape
b_scale_group_shape
=
[
&
,
&
s
=
b_scales
]()
->
GroupShape
{
if
(
s
.
numel
()
==
1
)
return
{
K
,
N
};
// tensor-wise
if
(
s
.
dim
()
==
2
)
return
{
ceil_div
(
b
.
size
(
0
),
s
.
size
(
0
)),
ceil_div
(
b
.
size
(
1
),
s
.
size
(
1
))};
TORCH_CHECK
(
false
,
"Unsupported scale shape for scale_b"
);
}();
if
((
a_scale_group_shape
==
GroupShape
{
M
,
K
}
||
a_scale_group_shape
==
GroupShape
{
1
,
K
})
&&
(
b_scale_group_shape
==
GroupShape
{
K
,
N
}
||
b_scale_group_shape
==
GroupShape
{
K
,
1
}))
{
// "standard per-tensor/per-token/per-channel" scaling
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
)
{
vllm
::
cutlass_scaled_mm_sm90_fp8
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
vllm
::
cutlass_scaled_mm_sm90_int8
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
}
else
if
(
a_scale_group_shape
==
GroupShape
{
1
,
128
}
&&
b_scale_group_shape
==
GroupShape
{
128
,
128
})
{
// 1x128 per-token group scales for activations
// 128x128 blockwise scales for weights
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
&&
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
,
"Currently only FP8 is supported for A group shape 1x128 and "
"B group shape 128x128"
);
TORCH_CHECK
(
!
bias
,
"Bias not yet supported blockwise scaled_mm"
);
vllm
::
cutlass_scaled_mm_blockwise_sm90_fp8
(
c
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
return
cutlass_scaled_mm_sm90_epilogue
<
c3x
::
ScaledEpilogue
>
(
c
,
a
,
b
,
a_scales
,
b_scales
);
TORCH_CHECK
(
false
,
"Unsupported scale group shapes for CUTLASS 3.x GEMM.
\n
"
"a_scale_group_shape must be [1, 128], got: ["
,
a_scale_group_shape
[
0
],
", "
,
a_scale_group_shape
[
1
],
"]
\n
"
"b_scale_group_shape must be [128, 128], got: ["
,
b_scale_group_shape
[
0
],
", "
,
b_scale_group_shape
[
1
],
"]"
);
}
}
...
...
@@ -70,18 +73,11 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
azp
)
{
return
cutlass_scaled_mm_sm90_epilogue
<
c3x
::
ScaledEpilogueBiasAzpToken
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
*
azp
,
bias
);
}
else
{
return
cutlass_scaled_mm_sm90_epilogue
<
c3x
::
ScaledEpilogueBiasAzp
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
bias
);
}
vllm
::
cutlass_scaled_mm_azp_sm90_int8
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
}
#endif
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
afd0da21
...
...
@@ -9,26 +9,26 @@ void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_sm80
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_sm89
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
void
cutlass_scaled_mm_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#endif
void
cutlass_scaled_mm_azp_sm75
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
...
...
@@ -36,24 +36,24 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_azp_sm80
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_azp_sm89
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void
cutlass_scaled_mm_azp_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
...
...
@@ -61,8 +61,8 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#endif
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
)
{
...
...
@@ -81,23 +81,33 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
return
false
;
}
bool
cutlass_scaled_mm_supports_block_fp8
(
int64_t
cuda_device_capability
)
{
// CUTLASS block-quantized FP8 kernels need at least CUDA 12.0
// and at least SM90 (Hopper)
#if defined CUDA_VERSION
if
(
cuda_device_capability
>=
90
)
{
return
CUDA_VERSION
>=
12000
;
}
#endif
return
false
;
}
void
cutlass_scaled_mm
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
// Checks for conformality
TORCH_CHECK
(
a
.
dim
()
==
2
&&
b
.
dim
()
==
2
&&
c
.
dim
()
==
2
);
TORCH_CHECK
(
c
.
size
(
0
)
==
a
.
size
(
0
)
&&
a
.
size
(
1
)
==
b
.
size
(
0
)
&&
b
.
size
(
1
)
==
c
.
size
(
1
));
TORCH_CHECK
(
a_scales
.
numel
()
==
1
||
a_scales
.
numel
()
==
a
.
size
(
0
));
TORCH_CHECK
(
b_scales
.
numel
()
==
1
||
b_scales
.
numel
()
==
b
.
size
(
1
));
// Check for strides and alignment
TORCH_CHECK
(
a
.
stride
(
1
)
==
1
&&
c
.
stride
(
1
)
==
1
);
// Row-major
TORCH_CHECK
(
b
.
stride
(
0
)
==
1
);
// Column-major
TORCH_CHECK
(
c
.
stride
(
0
)
%
16
==
0
&&
b
.
stride
(
1
)
%
16
==
0
);
// 16 Byte Alignment
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
bias
)
{
TORCH_CHECK
(
bias
->
numel
()
==
b
.
size
(
1
)
&&
bias
->
is_contiguous
()
&&
...
...
@@ -148,8 +158,8 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
// Checks for conformality
TORCH_CHECK
(
a
.
dim
()
==
2
&&
b
.
dim
()
==
2
&&
c
.
dim
()
==
2
);
TORCH_CHECK
(
c
.
size
(
0
)
==
a
.
size
(
0
)
&&
a
.
size
(
1
)
==
b
.
size
(
0
)
&&
...
...
@@ -215,4 +225,4 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
"CUDA device capability: "
,
version_num
);
}
\ No newline at end of file
}
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
afd0da21
...
...
@@ -173,8 +173,8 @@ dequant<half, vllm::kU4B8.id()>(int q) {
const
int
HI
=
0x00f000f0
;
const
int
EX
=
0x64006400
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const
int
SUB
=
0x64086408
;
...
...
@@ -197,9 +197,9 @@ dequant<nv_bfloat16, vllm::kU4B8.id()>(int q) {
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
q
>>=
4
;
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
typename
ScalarType
<
nv_bfloat16
>::
FragB
frag_b
;
static
constexpr
uint32_t
MUL
=
0x3F803F80
;
...
...
@@ -221,8 +221,8 @@ dequant<half, vllm::kU4.id()>(int q) {
const
int
HI
=
0x00f000f0
;
const
int
EX
=
0x64006400
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
const
int
SUB
=
0x64006400
;
const
int
MUL
=
0x2c002c00
;
...
...
@@ -244,9 +244,9 @@ dequant<nv_bfloat16, vllm::kU4.id()>(int q) {
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
q
>>=
4
;
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
typename
ScalarType
<
nv_bfloat16
>::
FragB
frag_b
;
static
constexpr
uint32_t
MUL
=
0x3F803F80
;
...
...
@@ -834,6 +834,7 @@ __global__ void Marlin(
int4
*
sh_g_idx
=
sh_b
+
(
stages
*
b_sh_stage
);
int4
*
sh_zp
=
sh_g_idx
+
(
stages
*
g_idx_stage
);
int4
*
sh_s
=
sh_zp
+
(
stages
*
zp_sh_stage
);
int4
*
sh_red
=
sh_s
+
(
stages
*
s_sh_stage
);
// Register storage for double buffer of shared memory reads.
FragA
frag_a
[
2
][
thread_m_blocks
];
...
...
@@ -932,11 +933,11 @@ __global__ void Marlin(
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s_stage
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
// Only fetch scales if this tile starts a new group
if
(
pipe
%
(
group_blocks
/
thread_k_blocks
)
==
0
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s_stage
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
if
((
pipe
+
1
)
%
(
group_blocks
/
thread_k_blocks
)
==
0
)
{
s_gl_rd
+=
s_gl_rd_delta
;
}
}
else
{
...
...
@@ -1038,9 +1039,7 @@ __global__ void Marlin(
// No act-order case
if
constexpr
(
group_blocks
!=
-
1
)
{
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
];
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
...
...
@@ -1339,15 +1338,15 @@ __global__ void Marlin(
int
red_sh_wr
=
red_sh_delta
*
j
+
(
red_sh_rd
-
red_sh_stride
*
i
);
if
(
i
<
red_off
)
{
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_delta
*
j
+
red_sh_rd
]);
float
*
c_wr
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_wr
]);
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
&
sh_red
[
red_sh_delta
*
j
+
red_sh_rd
]);
float
*
c_wr
=
reinterpret_cast
<
float
*>
(
&
sh
_red
[
red_sh_wr
]);
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
j
][
k
]
+=
c_rd
[
k
]
+
c_wr
[
k
];
}
sh
[
red_sh_wr
]
=
sh
_red
[
red_sh_wr
]
=
reinterpret_cast
<
int4
*>
(
&
frag_c
)[
4
*
2
*
m_block
+
j
];
}
}
...
...
@@ -1357,7 +1356,7 @@ __global__ void Marlin(
#pragma unroll
for
(
int
i
=
0
;
i
<
4
*
2
;
i
++
)
{
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_delta
*
i
+
red_sh_rd
]);
reinterpret_cast
<
float
*>
(
&
sh
_red
[
red_sh_delta
*
i
+
red_sh_rd
]);
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
i
][
j
]
+=
...
...
@@ -1397,7 +1396,7 @@ __global__ void Marlin(
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
cp_async4_pred
(
&
sh
[
c_sh_wr
+
c_sh_wr_delta
*
i
],
&
sh
_red
[
c_sh_wr
+
c_sh_wr_delta
*
i
],
&
C
[
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
c_gl_wr_delta_i
*
(
i
%
2
)],
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
row
<
prob_m
);
...
...
@@ -1410,7 +1409,7 @@ __global__ void Marlin(
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
if
(
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
row
<
prob_m
)
{
if
(
!
first
)
{
int4
c_red
=
sh
[
c_sh_wr
+
i
*
c_sh_wr_delta
];
int4
c_red
=
sh
_red
[
c_sh_wr
+
i
*
c_sh_wr_delta
];
#pragma unroll
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
reinterpret_cast
<
float
*>
(
...
...
@@ -1461,10 +1460,10 @@ __global__ void Marlin(
float
*
frag_c_ptr
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
#pragma unroll
for
(
int
k
=
0
;
k
<
th_size
;
k
++
)
{
sh
[
threadIdx
.
x
]
=
sh
_red
[
threadIdx
.
x
]
=
C_tmp
[
c_cur_offset
+
active_threads
*
k
+
threadIdx
.
x
];
float
*
sh_c_ptr
=
reinterpret_cast
<
float
*>
(
&
sh
[
threadIdx
.
x
]);
float
*
sh_c_ptr
=
reinterpret_cast
<
float
*>
(
&
sh
_red
[
threadIdx
.
x
]);
#pragma unroll
for
(
int
f
=
0
;
f
<
4
;
f
++
)
{
frag_c_ptr
[
k
*
4
+
f
]
+=
sh_c_ptr
[
f
];
...
...
@@ -1515,7 +1514,7 @@ __global__ void Marlin(
res
=
__hmul2
(
res
,
s
[
0
]);
}
((
scalar_t2
*
)
sh
)[
idx
]
=
res
;
((
scalar_t2
*
)
sh
_red
)[
idx
]
=
res
;
};
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
...
...
@@ -1543,7 +1542,7 @@ __global__ void Marlin(
i
<
div_ceil
(
16
*
thread_m_blocks
,
threads
/
(
2
*
thread_n_blocks
));
i
++
)
{
if
(
c_gl_wr
<
c_gl_wr_end
)
{
C
[
c_gl_wr
]
=
sh
[
c_sh_rd
];
C
[
c_gl_wr
]
=
sh
_red
[
c_sh_rd
];
c_gl_wr
+=
c_gl_wr_delta
;
c_sh_rd
+=
c_sh_rd_delta
;
}
...
...
@@ -1865,9 +1864,12 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
float
pipe_size
=
(
a_size
+
b_size
)
*
pipe_stages
;
float
reduce_size
=
max
(
th_config
.
num_threads
*
32
*
4
,
(
tb_n
/
64
)
*
32
*
(
tb_max_m
/
16
)
*
4
*
2
*
4
*
2
);
TORCH_CHECK
(
max_shared_mem
/
2
>
scales_cache_size
);
// Sanity
return
pipe_size
<
0.95
f
*
(
max_shared_mem
-
scales_cache_size
);
return
pipe_size
+
reduce_size
<
0.95
f
*
(
max_shared_mem
-
scales_cache_size
);
}
bool
is_valid_config
(
thread_config_t
const
&
th_config
,
int
max_m_blocks
,
...
...
csrc/quantization/machete/generate.py
View file @
afd0da21
...
...
@@ -63,7 +63,7 @@ torch::Tensor mm_dispatch_{{type_sig}}(MMArgs args) {
static inline std::optional<at::ScalarType> maybe_scalartype(
c10
::optional<at::Tensor> const& t) {
std
::optional<at::Tensor> const& t) {
if (!t) {
return std::nullopt;
} else {
...
...
@@ -189,7 +189,7 @@ using Kernel_{{type_sig}} = MacheteKernelTemplate<
{{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT
{{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT
{{DataTypeTag[t.a_token_scale]}}, // TokenScaleT
cutlass::gemm::KernelTmaWarpSpecializedCooperative
MixedInput
,
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
Sch>;
{% for sch in schs %}
...
...
@@ -223,7 +223,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
{{DataTypeTag[t.convert]}}, // ElementConvert
{{DataTypeTag[t.accumulator]}}, // Accumulator
cutlass::layout::ColumnMajor,
cutlass::gemm::KernelTmaWarpSpecializedCooperative
MixedInput
>
cutlass::gemm::KernelTmaWarpSpecializedCooperative>
>(args.B);
}
{%- endfor %}
...
...
@@ -239,7 +239,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
}; // namespace machete
"""
TmaMI
=
MixedInputKernelScheduleType
.
TmaWarpSpecializedCooperative
MixedInput
TmaMI
=
MixedInputKernelScheduleType
.
TmaWarpSpecializedCooperative
TmaCoop
=
EpilogueScheduleType
.
TmaWarpSpecializedCooperative
...
...
@@ -300,7 +300,7 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
# mostly unique shorter sch_sig
def
generate_terse_sch_sig
(
schedule_config
:
ScheduleConfig
)
->
str
:
kernel_terse_names_replace
=
{
"KernelTmaWarpSpecializedCooperative
MixedInput_
"
:
"TmaMI_"
,
"KernelTmaWarpSpecializedCooperative"
:
"TmaMI_"
,
"TmaWarpSpecializedCooperative_"
:
"TmaCoop_"
,
"StreamKScheduler"
:
"streamK"
,
}
...
...
csrc/quantization/machete/machete_collective_builder.cuh
View file @
afd0da21
...
...
@@ -18,16 +18,14 @@ struct VLLMCollectiveBuilder<
ElementAccumulator
,
TileShape_MNK
,
ClusterShape_MNK
,
StageCountType
,
KernelScheduleType
,
cute
::
enable_if_t
<
(
cute
::
is_same_v
<
KernelScheduleType
,
KernelTmaWarpSpecialized
>
||
cute
::
is_same_v
<
KernelScheduleType
,
KernelTmaWarpSpecializedPingpong
>
||
cute
::
is_same_v
<
KernelScheduleType
,
KernelTmaWarpSpecializedMixedInput
>
||
cute
::
is_same_v
<
KernelScheduleType
,
KernelTmaWarpSpecializedPingpongMixedInput
>
||
cute
::
is_same_v
<
KernelScheduleType
,
KernelTmaWarpSpecializedCooperativeMixedInput
>
)
>>
{
KernelTmaWarpSpecializedCooperative
>
)
>>
{
using
CollectiveOp
=
machete
::
MacheteCollectiveMma
<
ElementPairA_
,
GmemLayoutA_
,
AlignmentA
,
ElementPairB_
,
GmemLayoutB_
,
AlignmentB
,
ElementAccumulator
,
TileShape_MNK
,
ClusterShape_MNK
,
StageCountType
,
KernelScheduleType
>
;
};
};
// namespace cutlass::gemm::collective
\ No newline at end of file
};
// namespace cutlass::gemm::collective
csrc/quantization/machete/machete_mainloop.cuh
View file @
afd0da21
...
...
@@ -66,13 +66,11 @@ struct MacheteCollectiveMma {
using
Schedule
=
KernelScheduleType
;
static_assert
(
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecialized
>
||
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecializedMixedInput
>
||
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecialized
>
||
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecializedPingpong
>
||
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecializedPingpong
>
||
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecializedPingpongMixedInput
>
||
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecializedCooperative
>
||
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecializedCooperativeMixedInput
>
,
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecializedCooperative
>
,
"KernelSchedule must be one of the warp specialized policies"
);
public:
...
...
@@ -113,8 +111,7 @@ struct MacheteCollectiveMma {
// For coop schedules we have two warp groups cooperatively issuing wgmma
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
using
AtomLayoutMNK
=
cute
::
conditional_t
<
cute
::
is_same_v
<
KernelScheduleType
,
KernelTmaWarpSpecializedCooperativeMixedInput
>
,
cute
::
is_same_v
<
KernelScheduleType
,
KernelTmaWarpSpecializedCooperative
>
,
Layout
<
Shape
<
_2
,
_1
,
_1
>>
,
Layout
<
Shape
<
_1
,
_1
,
_1
>>>
;
using
TiledMma
=
decltype
(
cute
::
make_tiled_mma
(
...
...
@@ -275,6 +272,10 @@ struct MacheteCollectiveMma {
using
PipelineState
=
cutlass
::
PipelineState
<
DispatchPolicy
::
Stages
>
;
using
PipelineParams
=
typename
MainloopPipeline
::
Params
;
// One threads per CTA are producers (1 for operand tile)
static
constexpr
int
NumProducerThreadEvents
=
1
;
using
ScaleTileShape
=
decltype
(
make_shape
(
shape
<
0
>
(
TileShape
{}),
shape
<
1
>
(
SmemLayoutAtomScale
{})));
...
...
csrc/quantization/machete/machete_mm_kernel.cuh
View file @
afd0da21
...
...
@@ -183,11 +183,11 @@ struct MacheteKernelTemplate {
torch
::
Tensor
const
&
A
,
// MxK matrix
torch
::
Tensor
const
&
B
,
// KxN prepacked matrix
torch
::
Tensor
&
D
,
// MxN matrix
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_g_scales
,
// scale_KxN matrix
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_g_zeros
,
// scale_KxN matrix
c10
::
optional
<
int64_t
>
maybe_group_size
,
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_ch_scales
,
// len N vector
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_tok_scales
)
// len M vector
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_g_scales
,
// scale_KxN matrix
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_g_zeros
,
// scale_KxN matrix
std
::
optional
<
int64_t
>
maybe_group_size
,
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_ch_scales
,
// len N vector
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_tok_scales
)
// len M vector
{
static_assert
(
!
with_group_zeropoints
||
with_group_scales
);
...
...
csrc/quantization/machete/machete_mm_launcher.cuh
View file @
afd0da21
...
...
@@ -13,23 +13,23 @@ struct MMArgs {
torch
::
Tensor
const
&
A
;
torch
::
Tensor
const
&
B
;
vllm
::
ScalarType
const
&
b_type
;
c10
::
optional
<
at
::
ScalarType
>
const
&
maybe_out_type
;
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_group_scales
;
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_group_zeros
;
c10
::
optional
<
int64_t
>
maybe_group_size
;
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_channel_scales
;
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_token_scales
;
c10
::
optional
<
std
::
string
>
maybe_schedule
;
std
::
optional
<
at
::
ScalarType
>
const
&
maybe_out_type
;
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_group_scales
;
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_group_zeros
;
std
::
optional
<
int64_t
>
maybe_group_size
;
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_channel_scales
;
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_token_scales
;
std
::
optional
<
std
::
string
>
maybe_schedule
;
};
struct
SupportedSchedulesArgs
{
at
::
ScalarType
a_type
;
vllm
::
ScalarType
b_type
;
c10
::
optional
<
at
::
ScalarType
>
maybe_group_scales_type
;
c10
::
optional
<
at
::
ScalarType
>
maybe_group_zeros_type
;
c10
::
optional
<
at
::
ScalarType
>
maybe_channel_scales_type
;
c10
::
optional
<
at
::
ScalarType
>
maybe_token_scales_type
;
c10
::
optional
<
at
::
ScalarType
>
maybe_out_type
;
std
::
optional
<
at
::
ScalarType
>
maybe_group_scales_type
;
std
::
optional
<
at
::
ScalarType
>
maybe_group_zeros_type
;
std
::
optional
<
at
::
ScalarType
>
maybe_channel_scales_type
;
std
::
optional
<
at
::
ScalarType
>
maybe_token_scales_type
;
std
::
optional
<
at
::
ScalarType
>
maybe_out_type
;
};
torch
::
Tensor
mm_dispatch
(
MMArgs
args
);
...
...
csrc/quantization/machete/machete_prepack_launcher.cuh
View file @
afd0da21
...
...
@@ -10,7 +10,7 @@ struct PrepackBArgs {
torch
::
Tensor
const
&
B
;
at
::
ScalarType
a_type
;
vllm
::
ScalarType
b_type
;
c10
::
optional
<
at
::
ScalarType
>
maybe_group_scales_type
;
std
::
optional
<
at
::
ScalarType
>
maybe_group_scales_type
;
};
template
<
typename
PrepackedLayoutB
>
...
...
csrc/quantization/machete/machete_prepacked_layout.cuh
View file @
afd0da21
...
...
@@ -98,8 +98,7 @@ struct PrepackedLayoutBTemplate {
// For coop schedules we have two warp groups cooperatively issuing wgmma
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
using
AtomLayoutMNK
=
cute
::
conditional_t
<
cute
::
is_same_v
<
KernelSchedule
,
KernelTmaWarpSpecializedCooperativeMixedInput
>
,
cute
::
is_same_v
<
KernelSchedule
,
KernelTmaWarpSpecializedCooperative
>
,
Layout
<
Shape
<
_2
,
_1
,
_1
>>
,
Layout
<
Shape
<
_1
,
_1
,
_1
>>>
;
using
TiledMma
=
decltype
(
cute
::
make_tiled_mma
(
...
...
@@ -247,4 +246,4 @@ struct PrepackedLayoutBTemplate {
}
};
};
// namespace machete
\ No newline at end of file
};
// namespace machete
csrc/quantization/machete/machete_pytorch.cu
View file @
afd0da21
...
...
@@ -10,11 +10,11 @@ using namespace vllm;
std
::
vector
<
std
::
string
>
supported_schedules
(
at
::
ScalarType
a_type
,
int64_t
b_type_id
,
c10
::
optional
<
at
::
ScalarType
>
maybe_group_scales_type
,
c10
::
optional
<
at
::
ScalarType
>
maybe_group_zeros_type
,
c10
::
optional
<
at
::
ScalarType
>
maybe_channel_scales_type
,
c10
::
optional
<
at
::
ScalarType
>
maybe_token_scales_type
,
c10
::
optional
<
at
::
ScalarType
>
maybe_out_type
)
{
std
::
optional
<
at
::
ScalarType
>
maybe_group_scales_type
,
std
::
optional
<
at
::
ScalarType
>
maybe_group_zeros_type
,
std
::
optional
<
at
::
ScalarType
>
maybe_channel_scales_type
,
std
::
optional
<
at
::
ScalarType
>
maybe_token_scales_type
,
std
::
optional
<
at
::
ScalarType
>
maybe_out_type
)
{
ScalarType
const
b_type
=
ScalarType
::
from_id
(
b_type_id
);
return
supported_schedules_dispatch
({
.
a_type
=
a_type
,
...
...
@@ -29,13 +29,13 @@ std::vector<std::string> supported_schedules(
torch
::
Tensor
mm
(
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
B
,
int64_t
b_type_id
,
c10
::
optional
<
at
::
ScalarType
>
const
&
maybe_out_type
,
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_group_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_group_zeros
,
c10
::
optional
<
int64_t
>
maybe_group_size
,
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_channel_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
maybe_token_scales
,
c10
::
optional
<
std
::
string
>
maybe_schedule
)
{
std
::
optional
<
at
::
ScalarType
>
const
&
maybe_out_type
,
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_group_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_group_zeros
,
std
::
optional
<
int64_t
>
maybe_group_size
,
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_channel_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
maybe_token_scales
,
std
::
optional
<
std
::
string
>
maybe_schedule
)
{
ScalarType
const
b_type
=
ScalarType
::
from_id
(
b_type_id
);
return
mm_dispatch
({.
A
=
A
,
.
B
=
B
,
...
...
@@ -51,7 +51,7 @@ torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B,
torch
::
Tensor
prepack_B
(
torch
::
Tensor
const
&
B
,
at
::
ScalarType
const
&
a_type
,
int64_t
b_type_id
,
c10
::
optional
<
at
::
ScalarType
>
const
&
maybe_group_scales_type
)
{
std
::
optional
<
at
::
ScalarType
>
const
&
maybe_group_scales_type
)
{
ScalarType
const
b_type
=
ScalarType
::
from_id
(
b_type_id
);
return
prepack_B_dispatch
(
{.
B
=
B
,
...
...
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
View file @
afd0da21
...
...
@@ -96,8 +96,8 @@ __device__ inline FragB dequant(int q) {
const
int
HI
=
0x00f000f0
;
const
int
EX
=
0x64006400
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const
int
SUB
=
0x64086408
;
...
...
Prev
1
2
3
4
5
6
7
8
9
10
…
30
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment