Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
260d119e
Unverified
Commit
260d119e
authored
Jun 01, 2024
by
Tyler Michael Smith
Committed by
GitHub
Jun 01, 2024
Browse files
[Kernel] Refactor CUTLASS kernels to always take scales that reside on the GPU (#5137)
parent
a360ff80
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
445 additions
and
76 deletions
+445
-76
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp
...quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp
+28
-22
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
...quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
+389
-0
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
+4
-10
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
+10
-10
pyproject.toml
pyproject.toml
+1
-1
tests/kernels/test_cutlass.py
tests/kernels/test_cutlass.py
+10
-3
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
...d_tensors/schemes/compressed_tensors_w8a8_statictensor.py
+3
-30
No files found.
csrc/quantization/cutlass_w8a8/
cutlass_visitor_2x_
broadcast_epilogue.hpp
→
csrc/quantization/cutlass_w8a8/broadcast_
load_
epilogue
_c2x
.hpp
View file @
260d119e
...
...
@@ -33,20 +33,27 @@
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/visitor_load.hpp from
// https://github.com/NVIDIA/cutlass It's beem modified to support either
// row/column or scalar broadcasting, like is already supported in CUTLASS 3.x.
// Important because this saves us a factor 4x on the number of kernels
// compiled.
// https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either
// row/column or scalar broadcasting where the tensor being loaded from is
// always passed in via a device pointer. This lets one compiled kernel handle
// all cases of per-tensor or per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//
#pragma once
// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off
#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
#include "cute/tensor.hpp"
// clang-format on
namespace
cutlass
::
epilogue
::
threadblock
{
using
namespace
cute
;
...
...
@@ -59,9 +66,11 @@ template<
>
struct
VisitorRowOrScalarBroadcast
{
// This struct has been modified to have a bool indicating that ptr_row is a
// scalar that must be broadcast.
struct
Arguments
{
Element
const
*
ptr_row
=
nullptr
;
Element
null_default
=
Element
(
0
)
;
bool
row_broadcast
=
true
;
StrideMNL
dRow
=
{};
};
...
...
@@ -125,25 +134,25 @@ struct VisitorRowOrScalarBroadcast {
auto
coord_v
=
filter
(
tC_cRow
);
auto
dst_v
=
filter
(
tC_rRow
);
if
(
params_ptr
->
ptr_row
)
{
if
(
params_ptr
->
row_broadcast
)
{
// In this case we are loading from a row vector and broadcasting
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
size
(
src_v
);
++
i
)
{
bool
guard
=
get
<
1
>
(
coord_v
(
i
))
<
n
;
cutlass
::
arch
::
global_load
<
VecType
,
sizeof
(
VecType
)
>
(
dst_v
(
i
),
(
void
const
*
)
&
src_v
(
i
),
guard
);
cutlass
::
arch
::
global_load
<
VecType
,
sizeof
(
VecType
)
>
(
dst_v
(
i
),
(
void
const
*
)
&
src_v
(
i
),
guard
);
}
}
else
{
// In this case we are loading from a scalar and broadcasting
VecType
filled_vec
;
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
VecLength
;
i
++
)
{
reinterpret_cast
<
Element
*>
(
&
filled_vec
)[
i
]
=
params_ptr
->
null_default
;
reinterpret_cast
<
Element
*>
(
&
filled_vec
)[
i
]
=
*
(
params_ptr
->
ptr_row
)
;
}
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
size
(
src_v
);
++
i
)
{
if
(
get
<
1
>
(
coord_v
(
i
))
<
n
)
{
if
(
get
<
1
>
(
coord_v
(
i
))
<
n
)
{
dst_v
(
i
)
=
filled_vec
;
}
}
...
...
@@ -208,9 +217,11 @@ template<
>
struct
VisitorColOrScalarBroadcast
{
// This struct has been modified to have a bool indicating that ptr_col is a
// scalar that must be broadcast.
struct
Arguments
{
Element
const
*
ptr_col
=
nullptr
;
Element
null_default
=
Element
(
0
)
;
bool
col_broadcast
=
true
;
StrideMNL
dCol
=
{};
};
...
...
@@ -230,11 +241,6 @@ struct VisitorColOrScalarBroadcast {
struct
SharedStorage
{
};
// Global load type
static
int
constexpr
vec_bits
=
ThreadMap
::
kElementsPerAccess
*
sizeof_bits
<
Element
>::
value
;
using
VecType
=
uint_bit_t
<
cute
::
min
(
128
,
vec_bits
)
>
;
static
int
constexpr
VecLength
=
sizeof
(
VecType
)
/
sizeof
(
Element
);
CUTLASS_HOST_DEVICE
VisitorColOrScalarBroadcast
()
{
}
...
...
@@ -267,7 +273,7 @@ struct VisitorColOrScalarBroadcast {
int
m
;
// This function is modified from VisitorColBroadcast
CUTLASS_DEVICE
void
CUTLASS_DEVICE
void
begin_epilogue
()
{
clear
(
tC_rCol
);
...
...
@@ -277,7 +283,7 @@ struct VisitorColOrScalarBroadcast {
pred
(
i
)
=
get
<
0
>
(
tC_cCol
(
i
))
<
m
;
}
if
(
params_ptr
->
ptr_col
)
{
if
(
params_ptr
->
col_broadcast
)
{
// In this case we are loading from a column vector and broadcasting
copy_if
(
pred
,
tC_gCol
,
tC_rCol
);
}
else
{
...
...
@@ -286,8 +292,8 @@ struct VisitorColOrScalarBroadcast {
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
size
(
dst_v
);
++
i
)
{
if
(
pred
(
i
)){
dst_v
(
i
)
=
params_ptr
->
null_default
;
if
(
pred
(
i
))
{
dst_v
(
i
)
=
*
(
params_ptr
->
ptr_col
)
;
}
}
}
...
...
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
0 → 100644
View file @
260d119e
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
// from https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either row/column or scalar broadcasting
// where the tensor being loaded from is always passed in via a device pointer.
// This lets one compiled kernel handle all cases of per-tensor or
// per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graphs
// breaks when moving scales to the CPU.
//
#pragma once
// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off
#include "cutlass/cutlass.h"
#include "cutlass/arch/barrier.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
namespace
cutlass
::
epilogue
::
fusion
{
using
namespace
cute
;
using
namespace
detail
;
// Row vector broadcast
template
<
// Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least
// ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races
int
Stages
,
class
CtaTileShapeMNK
,
class
Element
,
class
StrideMNL
=
Stride
<
_0
,
_1
,
_0
>,
int
Alignment
=
128
/
sizeof_bits_v
<
Element
>
>
struct
Sm90RowOrScalarBroadcast
{
static_assert
(
Alignment
*
sizeof_bits_v
<
Element
>
%
128
==
0
,
"sub-16B alignment not supported yet"
);
static_assert
(
(
cute
::
is_same_v
<
StrideMNL
,
Stride
<
_0
,
_1
,
_0
>>
)
||
// row vector broadcast, e.g. per-col alpha/bias
(
cute
::
is_same_v
<
StrideMNL
,
Stride
<
_0
,
_1
,
int
>>
));
// batched row vector broadcast
// Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem
struct
SharedStorage
{
alignas
(
16
)
array_aligned
<
Element
,
size
<
1
>
(
CtaTileShapeMNK
{})
*
Stages
>
smem_row
;
};
// This struct has been modified to have a bool indicating that ptr_row is a
// scalar that must be broadcast, instead of containing a scalar that is
// valid if ptr_row is null.
struct
Arguments
{
Element
const
*
ptr_row
=
nullptr
;
bool
row_broadcast
=
true
;
StrideMNL
dRow
=
{};
};
using
Params
=
Arguments
;
template
<
class
ProblemShape
>
static
constexpr
Params
to_underlying_arguments
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
,
void
*
workspace
)
{
return
args
;
}
template
<
class
ProblemShape
>
static
size_t
get_workspace_size
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
return
0
;
}
template
<
class
ProblemShape
>
static
cutlass
::
Status
initialize_workspace
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
,
void
*
workspace
,
cudaStream_t
stream
,
CudaHostAdapter
*
cuda_adapter
=
nullptr
)
{
return
cutlass
::
Status
::
kSuccess
;
}
CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcast
()
{
}
CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcast
(
Params
const
&
params
,
SharedStorage
const
&
shared_storage
)
:
params
(
params
),
smem_row
(
const_cast
<
Element
*>
(
shared_storage
.
smem_row
.
data
()))
{
}
Params
params
;
Element
*
smem_row
;
CUTLASS_DEVICE
bool
is_producer_load_needed
()
const
{
return
true
;
}
CUTLASS_DEVICE
bool
is_C_load_needed
()
const
{
return
false
;
}
CUTLASS_DEVICE
bool
is_zero
()
const
{
return
(
!
params
.
row_broadcast
&&
*
(
params
.
ptr_row
)
==
Element
(
0
));
}
template
<
int
EpiTiles
,
class
GTensor
,
class
STensor
>
struct
ProducerLoadCallbacks
:
EmptyProducerLoadCallbacks
{
CUTLASS_DEVICE
ProducerLoadCallbacks
(
GTensor
&&
gRow
,
STensor
&&
sRow
,
Params
const
&
params
)
:
gRow
(
cute
::
forward
<
GTensor
>
(
gRow
)),
sRow
(
cute
::
forward
<
STensor
>
(
sRow
)),
params
(
params
)
{}
GTensor
gRow
;
// (CTA_M,CTA_N)
STensor
sRow
;
// (CTA_M,CTA_N,PIPE)
Params
const
&
params
;
CUTLASS_DEVICE
void
begin
(
uint64_t
*
full_mbarrier_ptr
,
int
load_iteration
,
bool
issue_tma_load
)
{
if
(
params
.
ptr_row
==
nullptr
)
{
return
;
}
if
(
issue_tma_load
)
{
// Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size
constexpr
uint32_t
copy_bytes
=
size
<
1
>
(
CtaTileShapeMNK
{})
*
sizeof_bits_v
<
Element
>
/
8
;
cutlass
::
arch
::
ClusterTransactionBarrier
::
expect_transaction
(
full_mbarrier_ptr
,
copy_bytes
);
// Issue the TMA bulk copy
auto
bulk_copy
=
Copy_Atom
<
SM90_BULK_COPY_AUTO
,
Element
>
{}.
with
(
*
full_mbarrier_ptr
);
// Filter so we don't issue redundant copies over stride-0 modes
int
bcast_pipe_index
=
(
load_iteration
/
EpiTiles
)
%
Stages
;
copy
(
bulk_copy
,
filter
(
gRow
),
filter
(
sRow
(
_
,
_
,
bcast_pipe_index
)));
}
}
};
template
<
class
...
Args
>
CUTLASS_DEVICE
auto
get_producer_load_callbacks
(
ProducerLoadArgs
<
Args
...
>
const
&
args
)
{
auto
[
M
,
N
,
K
,
L
]
=
args
.
problem_shape_mnkl
;
auto
[
m
,
n
,
k
,
l
]
=
args
.
tile_coord_mnkl
;
Tensor
mRow
=
make_tensor
(
make_gmem_ptr
(
params
.
ptr_row
),
make_shape
(
M
,
N
,
L
),
params
.
dRow
);
Tensor
gRow
=
local_tile
(
mRow
,
take
<
0
,
2
>
(
args
.
tile_shape_mnk
),
make_coord
(
m
,
n
,
l
));
// (CTA_M,CTA_N)
Tensor
sRow
=
make_tensor
(
make_smem_ptr
(
smem_row
),
// (CTA_M,CTA_N,PIPE)
make_shape
(
size
<
0
>
(
CtaTileShapeMNK
{}),
size
<
1
>
(
CtaTileShapeMNK
{}),
Stages
),
make_stride
(
_0
{},
_1
{},
size
<
1
>
(
CtaTileShapeMNK
{})));
constexpr
int
EpiTiles
=
decltype
(
size
<
1
>
(
zipped_divide
(
make_layout
(
take
<
0
,
2
>
(
args
.
tile_shape_mnk
)),
args
.
epi_tile
)))
::
value
;
return
ProducerLoadCallbacks
<
EpiTiles
,
decltype
(
gRow
),
decltype
(
sRow
)
>
(
cute
::
move
(
gRow
),
cute
::
move
(
sRow
),
params
);
}
template
<
int
EpiTiles
,
class
RTensor
,
class
STensor
>
struct
ConsumerStoreCallbacks
:
EmptyConsumerStoreCallbacks
{
CUTLASS_DEVICE
ConsumerStoreCallbacks
(
RTensor
&&
tCrRow
,
STensor
&&
tCsRow
,
Params
const
&
params
)
:
tCrRow
(
cute
::
forward
<
RTensor
>
(
tCrRow
)),
tCsRow
(
cute
::
forward
<
STensor
>
(
tCsRow
)),
params
(
params
)
{}
RTensor
tCrRow
;
// (CPY,CPY_M,CPY_N)
STensor
tCsRow
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
Params
const
&
params
;
CUTLASS_DEVICE
void
previsit
(
int
epi_m
,
int
epi_n
,
int
load_iteration
,
bool
is_producer_load_needed
)
{
if
(
!
params
.
row_broadcast
)
{
fill
(
tCrRow
,
*
(
params
.
ptr_row
));
return
;
}
if
(
epi_m
==
0
)
{
// Assumes M-major subtile loop
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
int
bcast_pipe_index
=
(
load_iteration
/
EpiTiles
)
%
Stages
;
copy_aligned
(
filter
(
tCsRow
(
_
,
_
,
_
,
epi_m
,
epi_n
,
bcast_pipe_index
)),
filter
(
tCrRow
));
}
}
template
<
typename
ElementAccumulator
,
int
FragmentSize
>
CUTLASS_DEVICE
Array
<
Element
,
FragmentSize
>
visit
(
Array
<
ElementAccumulator
,
FragmentSize
>
const
&
frg_acc
,
int
epi_v
,
int
epi_m
,
int
epi_n
)
{
Array
<
Element
,
FragmentSize
>
frg_row
;
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
FragmentSize
;
++
i
)
{
frg_row
[
i
]
=
tCrRow
(
epi_v
*
FragmentSize
+
i
);
}
return
frg_row
;
}
};
template
<
bool
ReferenceSrc
,
// do register tensors reference the src or dst layout of the tiled copy
class
...
Args
>
CUTLASS_DEVICE
auto
get_consumer_store_callbacks
(
ConsumerStoreArgs
<
Args
...
>
const
&
args
)
{
Tensor
sRow
=
make_tensor
(
make_smem_ptr
(
smem_row
),
// (CTA_M,CTA_N,PIPE)
make_shape
(
size
<
0
>
(
CtaTileShapeMNK
{}),
size
<
1
>
(
CtaTileShapeMNK
{}),
Stages
),
make_stride
(
_0
{},
_1
{},
size
<
1
>
(
CtaTileShapeMNK
{})));
Tensor
tCsRow
=
sm90_partition_for_epilogue
<
ReferenceSrc
>
(
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
sRow
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
Tensor
tCrRow
=
make_tensor_like
(
take
<
0
,
3
>
(
tCsRow
));
// (CPY,CPY_M,CPY_N)
constexpr
int
EpiTiles
=
decltype
(
size
<
1
>
(
zipped_divide
(
make_layout
(
take
<
0
,
2
>
(
args
.
tile_shape_mnk
)),
args
.
epi_tile
)))
::
value
;
return
ConsumerStoreCallbacks
<
EpiTiles
,
decltype
(
tCrRow
),
decltype
(
tCsRow
)
>
(
cute
::
move
(
tCrRow
),
cute
::
move
(
tCsRow
),
params
);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Column vector broadcast
template
<
int
Stages
,
class
CtaTileShapeMNK
,
class
Element
,
class
StrideMNL
=
Stride
<
_1
,
_0
,
_0
>,
int
Alignment
=
128
/
sizeof_bits_v
<
Element
>
>
struct
Sm90ColOrScalarBroadcast
{
static_assert
(
Stages
==
0
,
"Column broadcast doesn't support smem usage yet"
);
static_assert
(
Alignment
*
sizeof_bits_v
<
Element
>
%
128
==
0
,
"sub-16B alignment not supported yet"
);
static_assert
(
(
cute
::
is_same_v
<
StrideMNL
,
Stride
<
_1
,
_0
,
_0
>>
)
||
// col vector broadcast, e.g. per-row alpha/bias
(
cute
::
is_same_v
<
StrideMNL
,
Stride
<
_1
,
_0
,
int
>>
));
// batched col vector broadcast, e.g. batched per-row bias
// Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
struct
SharedStorage
{
};
// This struct has been modified to have a bool indicating that ptr_col is a
// scalar that must be broadcast, instead of containing a scalar that is
// valid if ptr_col is null.
struct
Arguments
{
Element
const
*
ptr_col
=
nullptr
;
bool
col_broadcast
=
true
;
StrideMNL
dCol
=
{};
};
using
Params
=
Arguments
;
template
<
class
ProblemShape
>
static
constexpr
Params
to_underlying_arguments
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
,
void
*
workspace
)
{
return
args
;
}
template
<
class
ProblemShape
>
static
size_t
get_workspace_size
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
return
0
;
}
template
<
class
ProblemShape
>
static
cutlass
::
Status
initialize_workspace
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
,
void
*
workspace
,
cudaStream_t
stream
,
CudaHostAdapter
*
cuda_adapter
=
nullptr
)
{
return
cutlass
::
Status
::
kSuccess
;
}
CUTLASS_DEVICE
bool
is_producer_load_needed
()
const
{
return
false
;
}
CUTLASS_DEVICE
bool
is_C_load_needed
()
const
{
return
false
;
}
CUTLASS_DEVICE
bool
is_zero
()
const
{
return
(
!
params
.
col_broadcast
&&
*
(
params
.
ptr_col
)
==
Element
(
0
));
}
CUTLASS_HOST_DEVICE
Sm90ColOrScalarBroadcast
()
{
}
CUTLASS_HOST_DEVICE
Sm90ColOrScalarBroadcast
(
Params
const
&
params
,
SharedStorage
const
&
shared_storage
)
:
params
(
params
)
{
}
Params
params
;
template
<
class
...
Args
>
CUTLASS_DEVICE
auto
get_producer_load_callbacks
(
ProducerLoadArgs
<
Args
...
>
const
&
args
)
{
return
EmptyProducerLoadCallbacks
{};
}
template
<
class
GTensor
,
class
RTensor
>
struct
ConsumerStoreCallbacks
:
EmptyConsumerStoreCallbacks
{
CUTLASS_DEVICE
ConsumerStoreCallbacks
(
GTensor
&&
tCgCol
,
RTensor
&&
tCrCol
,
Params
const
&
params
)
:
tCgCol
(
cute
::
forward
<
GTensor
>
(
tCgCol
)),
tCrCol
(
cute
::
forward
<
RTensor
>
(
tCrCol
)),
params
(
params
)
{}
GTensor
tCgCol
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor
tCrCol
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Params
const
&
params
;
CUTLASS_DEVICE
void
begin
()
{
if
(
!
params
.
col_broadcast
)
{
fill
(
tCrCol
,
*
(
params
.
ptr_col
));
return
;
}
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
copy_aligned
(
filter
(
tCgCol
),
filter
(
tCrCol
));
}
template
<
typename
ElementAccumulator
,
int
FragmentSize
>
CUTLASS_DEVICE
Array
<
Element
,
FragmentSize
>
visit
(
Array
<
ElementAccumulator
,
FragmentSize
>
const
&
frg_acc
,
int
epi_v
,
int
epi_m
,
int
epi_n
)
{
Array
<
Element
,
FragmentSize
>
frg_col
;
Tensor
tCrCol_mn
=
tCrCol
(
_
,
_
,
_
,
epi_m
,
epi_n
);
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
FragmentSize
;
++
i
)
{
frg_col
[
i
]
=
tCrCol_mn
(
epi_v
*
FragmentSize
+
i
);
}
return
frg_col
;
}
};
template
<
bool
ReferenceSrc
,
// do register tensors reference the src or dst layout of the tiled copy
class
...
Args
>
CUTLASS_DEVICE
auto
get_consumer_store_callbacks
(
ConsumerStoreArgs
<
Args
...
>
const
&
args
)
{
auto
[
M
,
N
,
K
,
L
]
=
args
.
problem_shape_mnkl
;
Tensor
mCol
=
make_tensor
(
make_gmem_ptr
(
params
.
ptr_col
),
make_shape
(
M
,
N
,
L
),
params
.
dCol
);
Tensor
tCgCol
=
sm90_partition_for_epilogue
<
ReferenceSrc
>
(
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol
,
args
.
tile_shape_mnk
,
args
.
tile_coord_mnkl
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
Tensor
tCrCol
=
make_tensor_like
(
tCgCol
);
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
return
ConsumerStoreCallbacks
<
decltype
(
tCgCol
),
decltype
(
tCrCol
)
>
(
cute
::
move
(
tCgCol
),
cute
::
move
(
tCrCol
),
params
);
}
};
}
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
View file @
260d119e
...
...
@@ -22,7 +22,7 @@
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "
cutlass_visitor_2x_
broadcast_epilogue.hpp"
#include "broadcast_
load_
epilogue
_c2x
.hpp"
#include "common.hpp"
// clang-format on
...
...
@@ -145,17 +145,11 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
auto
a_scales_ptr
=
a_scales
.
data_ptr
<
float
>
();
auto
b_scales_ptr
=
b_scales
.
data_ptr
<
float
>
();
// If A and B are quantized per-tensor, then these scale tensors are scalars,
// and they are passed in via the second argument.
using
ScaleAArgs
=
typename
Gemm
::
ScaleA
::
Arguments
;
ScaleAArgs
a_args
=
a_scales
.
numel
()
==
1
?
ScaleAArgs
{
nullptr
,
a_scales
.
item
<
float
>
(),
{}}
:
ScaleAArgs
{
a_scales
.
data_ptr
<
float
>
(),
{},
{}};
using
ScaleBArgs
=
typename
Gemm
::
ScaleB
::
Arguments
;
ScaleBArgs
b_args
=
b_scales
.
numel
()
==
1
?
ScaleBArgs
{
nullptr
,
b_scales
.
item
<
float
>
()
,
{}}
:
Scale
B
Args
{
b
_scales
.
data_ptr
<
float
>
(),
{}
,
{}};
ScaleBArgs
b_args
{
b_scales
.
data_ptr
<
float
>
(),
b_scales
.
numel
()
!=
1
,
{}}
;
Scale
A
Args
a_args
{
a
_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
()
!=
1
,
{}};
typename
Gemm
::
EVTCompute0
::
Arguments
evt0_compute_args
{
b_args
};
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
View file @
260d119e
...
...
@@ -18,11 +18,14 @@
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "broadcast_load_epilogue_c3x.hpp"
#include "common.hpp"
// clang-format on
...
...
@@ -65,7 +68,7 @@ struct cutlass_3x_gemm {
using
Accum
=
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
;
using
ScaleA
=
cutlass
::
epilogue
::
fusion
::
Sm90ColBroadcast
<
using
ScaleA
=
cutlass
::
epilogue
::
fusion
::
Sm90Col
OrScalar
Broadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
float
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
...
...
@@ -73,7 +76,7 @@ struct cutlass_3x_gemm {
cutlass
::
epilogue
::
collective
::
detail
::
RowBroadcastDescriptor
<
EpilogueDescriptor
,
float
>
;
using
ScaleB
=
cutlass
::
epilogue
::
fusion
::
Sm90RowBroadcast
<
using
ScaleB
=
cutlass
::
epilogue
::
fusion
::
Sm90Row
OrScalar
Broadcast
<
ScaleBDescriptor
::
Stages
,
typename
EpilogueDescriptor
::
TileShape
,
typename
ScaleBDescriptor
::
Element
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
...
...
@@ -166,13 +169,9 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
using
ScaleA_Args
=
typename
Gemm
::
ScaleA
::
Arguments
;
using
ScaleB_Args
=
typename
Gemm
::
ScaleB
::
Arguments
;
ScaleA_Args
a_args
=
a_scales
.
numel
()
==
1
?
ScaleA_Args
{
nullptr
,
a_scales
.
item
<
float
>
(),
{}}
:
ScaleA_Args
{
a_scales
.
data_ptr
<
float
>
(),
{},
{}};
ScaleB_Args
b_args
=
b_scales
.
numel
()
==
1
?
ScaleB_Args
{
nullptr
,
b_scales
.
item
<
float
>
(),
{}}
:
ScaleB_Args
{
b_scales
.
data_ptr
<
float
>
(),
{},
{}};
ScaleA_Args
a_args
{
a_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
()
!=
1
,
{}};
ScaleB_Args
b_args
{
b_scales
.
data_ptr
<
float
>
(),
b_scales
.
numel
()
!=
1
,
{}};
args
.
epilogue
.
thread
=
{
a_args
,
{
b_args
}};
...
...
@@ -182,10 +181,11 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
CUTLASS_CHECK
(
gemm_op
.
can_implement
(
args
));
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
TORCH_CHECK
(
workspace_size
==
0
);
cutlass
::
device_memory
::
allocation
<
uint8_t
>
workspace
(
workspace_size
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
cutlass
::
Status
status
=
gemm_op
.
run
(
args
,
stream
);
cutlass
::
Status
status
=
gemm_op
.
run
(
args
,
workspace
.
get
(),
stream
);
CUTLASS_CHECK
(
status
);
}
}
// namespace
...
...
pyproject.toml
View file @
260d119e
...
...
@@ -59,7 +59,7 @@ exclude = [
]
[tool.codespell]
ignore-words-list
=
"dout, te, indicies"
ignore-words-list
=
"dout, te, indicies
, subtile
"
skip
=
"./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"
[tool.isort]
...
...
tests/kernels/test_cutlass.py
View file @
260d119e
...
...
@@ -207,14 +207,21 @@ class CutlassLayer(torch.nn.Module):
self
.
out_dtype
)
def
test_cutlass_cuda_graph
():
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
def
test_cutlass_cuda_graph
(
per_act_token
:
bool
,
per_out_ch
:
bool
):
m
,
n
,
k
=
512
,
512
,
512
a
=
to_int8
(
torch
.
randn
((
m
,
k
),
device
=
"cuda"
))
b
=
to_int8
(
torch
.
randn
((
n
,
k
),
device
=
"cuda"
).
t
())
scale_a
=
(
torch
.
randn
((
m
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
)
scale_b
=
(
torch
.
randn
((
1
,
n
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
)
m_a_scales
=
m
if
per_act_token
else
1
n_b_scales
=
n
if
per_out_ch
else
1
scale_a
=
(
torch
.
randn
(
(
m_a_scales
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
)
scale_b
=
(
torch
.
randn
(
(
1
,
n_b_scales
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
)
# Construct a trivial model with a single layer that calls a CUTLASS kernel
model
=
CutlassLayer
(
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
View file @
260d119e
...
...
@@ -41,46 +41,19 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
# TODO: remove zero_point parameters once the configs given remove them
# Note on input/weight scales and zero_points
#
# When the scales have a single value, it is required that they be
# on the CPU for 2 reasons,
# 1. Performance:
# When the scales (input_scale/weight_scales) have only a single
# value, we perform a scalar broadcast of that value during the
# quant/dequant operations. The "quant" and the "gemm+dequant"
# kernels accept the Scalar by-value. These tensors are allocated
# on the CPU in order to avoid the GPU-to-CPU copy when passing
# by-value.
#
# 2. CUDA Graphs:
# CUDA Graphs don't support GPU-to-CPU copy operations during
# stream capture.
#
# TODO: zero-points are not supported yet. But we expect a similar
# pattern.
is_tensor_partitioned
=
len
(
output_partition_sizes
)
!=
1
weight_scale_dim
=
sum
(
output_partition_sizes
)
if
is_tensor_partitioned
else
1
weight_scale_device
=
"cpu"
if
weight_scale_dim
==
1
else
"cuda"
input_scale
=
Parameter
(
torch
.
empty
(
1
,
device
=
"cpu"
,
dtype
=
torch
.
float32
),
input_scale
=
Parameter
(
torch
.
empty
(
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
input_zero_point
=
Parameter
(
torch
.
empty
(
1
,
device
=
"cpu"
,
dtype
=
torch
.
int8
),
input_zero_point
=
Parameter
(
torch
.
empty
(
1
,
dtype
=
torch
.
int8
),
requires_grad
=
False
)
weight_scale
=
Parameter
(
torch
.
empty
(
weight_scale_dim
,
device
=
weight_scale_device
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
weight_zero_point
=
Parameter
(
torch
.
empty
(
1
,
device
=
"cpu"
,
dtype
=
torch
.
int8
),
weight_zero_point
=
Parameter
(
torch
.
empty
(
1
,
dtype
=
torch
.
int8
),
requires_grad
=
False
)
weight
=
Parameter
(
torch
.
empty
(
sum
(
output_partition_sizes
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment