Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
55712941
Unverified
Commit
55712941
authored
Jul 26, 2024
by
Lucas Wilkinson
Committed by
GitHub
Jul 27, 2024
Browse files
[Bug Fix] Illegal memory access, FP8 Llama 3.1 405b (#6852)
parent
981b0d56
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
9 deletions
+37
-9
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
...quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
+37
-9
No files found.
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
View file @
55712941
...
...
@@ -328,20 +328,36 @@ struct Sm90ColOrScalarBroadcast {
return
EmptyProducerLoadCallbacks
{};
}
template
<
class
GTensor
,
class
RTensor
>
template
<
class
GTensor
,
class
RTensor
,
class
CTensor
,
class
ProblemShape
>
struct
ConsumerStoreCallbacks
:
EmptyConsumerStoreCallbacks
{
CUTLASS_DEVICE
ConsumerStoreCallbacks
(
GTensor
&&
tCgCol
,
RTensor
&&
tCrCol
,
Params
const
&
params
)
:
tCgCol
(
cute
::
forward
<
GTensor
>
(
tCgCol
)),
ConsumerStoreCallbacks
(
GTensor
&&
tCgCol
,
RTensor
&&
tCrCol
,
CTensor
&&
tCcCol
,
ProblemShape
problem_shape
,
Params
const
&
params
)
:
tCgCol
(
cute
::
forward
<
GTensor
>
(
tCgCol
)),
tCrCol
(
cute
::
forward
<
RTensor
>
(
tCrCol
)),
tCcCol
(
cute
::
forward
<
CTensor
>
(
tCcCol
)),
m
(
get
<
0
>
(
problem_shape
)),
params
(
params
)
{}
GTensor
tCgCol
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor
tCrCol
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor
tCrCol
;
CTensor
tCcCol
;
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Params
const
&
params
;
int
m
;
CUTLASS_DEVICE
void
begin
()
{
Tensor
pred
=
make_tensor
<
bool
>
(
shape
(
tCgCol
));
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
size
(
pred
);
++
i
)
{
pred
(
i
)
=
get
<
0
>
(
tCcCol
(
i
))
<
m
;
}
if
(
!
params
.
col_broadcast
)
{
fill
(
tCrCol
,
*
(
params
.
ptr_col
));
return
;
...
...
@@ -349,7 +365,7 @@ struct Sm90ColOrScalarBroadcast {
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
copy_
aligned
(
filter
(
tCgCol
),
filter
(
tCrCol
));
copy_
if
(
pred
,
filter
(
tCgCol
),
filter
(
tCrCol
));
}
template
<
typename
ElementAccumulator
,
int
FragmentSize
>
...
...
@@ -381,8 +397,20 @@ struct Sm90ColOrScalarBroadcast {
mCol
,
args
.
tile_shape_mnk
,
args
.
tile_coord_mnkl
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
Tensor
tCrCol
=
make_tensor_like
(
tCgCol
);
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
return
ConsumerStoreCallbacks
<
decltype
(
tCgCol
),
decltype
(
tCrCol
)
>
(
cute
::
move
(
tCgCol
),
cute
::
move
(
tCrCol
),
params
);
// Generate an identity tensor matching the shape of the global tensor and
// partition the same way, this will be used to generate the predicate
// tensor for loading
Tensor
cCol
=
make_identity_tensor
(
mCol
.
shape
());
Tensor
tCcCol
=
sm90_partition_for_epilogue
<
ReferenceSrc
>
(
// (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol
,
args
.
tile_shape_mnk
,
args
.
tile_coord_mnkl
,
args
.
epi_tile
,
args
.
tiled_copy
,
args
.
thread_idx
);
return
ConsumerStoreCallbacks
(
cute
::
move
(
tCgCol
),
cute
::
move
(
tCrCol
),
cute
::
move
(
tCcCol
),
args
.
problem_shape_mnkl
,
params
);
}
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment