gaoqiong / composable_kernel_ROCM / Commits / b07e21a4

Commit b07e21a4, authored Oct 31, 2024 by mtgu0705

Enable pipeline v1 & v2 for int8 quant

parent 80b46cae
Showing 8 changed files with 1242 additions and 87 deletions (+1242 −87)
example/65_gemm_multiply_multiply/gemm_fp16int4_b_scale.cpp (+1 −1)
example/65_gemm_multiply_multiply/gemm_fp16int8_b_scale.cpp (+1 −1)
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp (+75 −22)
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp (+437 −0)
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp (+654 −0)
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp (+1 −1)
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_scale.hpp (+2 −0)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_scale.hpp (+71 −62)
example/65_gemm_multiply_multiply/gemm_fp16int4_b_scale.cpp

@@ -80,7 +80,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BScale_X
         // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
         1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
         // ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
-        ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, false, PermuteB>;
+        ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, PermuteB>;
     // clang-format on
...
example/65_gemm_multiply_multiply/gemm_fp16int8_b_scale.cpp

@@ -81,7 +81,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BScale_X
         // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
         1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
         // ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
-        ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>;
+        ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>;
     // clang-format on

 template <typename IntType>
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp

@@ -3,8 +3,8 @@
 #pragma once

-// #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp"
-// #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp"
+#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp"
+#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp"

 namespace ck {
...
@@ -39,26 +39,79 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
           index_t KPack>
 constexpr auto BlockGemmBScalePipeline_Selector()
 {
-    return BlockwiseGemmXdlops_pipeline_v3_b_scale<BlkGemmPipeSche, BlockSize, ADataType,
-        BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
-        ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock,
-        KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>{};
+    if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
+    {
+        return BlockwiseGemmXdlops_pipeline_v1_b_scale<BlkGemmPipeSche, BlockSize, ADataType,
+            BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc,
+            BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
+            MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>{};
+    }
+    else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
+    {
+        return BlockwiseGemmXdlops_pipeline_v2_b_scale<BlkGemmPipeSche, BlockSize, ADataType,
+            BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc,
+            BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
+            MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>{};
+    }
+    else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
+    {
+        return BlockwiseGemmXdlops_pipeline_v3_b_scale<BlkGemmPipeSche, BlockSize, ADataType,
+            BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc,
+            BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
+            MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>{};
+    }
+    else
+    {
+        std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
+    }
 }

 } // namespace ck
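The selector now dispatches on BlkGemmPipelineVer at compile time instead of always returning the v3 pipeline. Below is a minimal standalone sketch of that `if constexpr` dispatch pattern; the PipelineVersion enum and PipelineV1/V2/V3 structs are hypothetical stand-ins for the CK types, only the pattern is the same.

// Minimal standalone sketch of the compile-time dispatch used by
// BlockGemmBScalePipeline_Selector. All names are hypothetical stand-ins.
#include <iostream>

enum class PipelineVersion { v1, v2, v3 };

struct PipelineV1 { static constexpr int prefetch_stages = 1; };
struct PipelineV2 { static constexpr int prefetch_stages = 2; };
struct PipelineV3 { static constexpr int prefetch_stages = 2; };

template <PipelineVersion Ver>
constexpr auto pipeline_selector()
{
    // Each branch returns a different type, so the deduced return type
    // depends on the template argument -- the same trick as the selector.
    if constexpr(Ver == PipelineVersion::v1)
        return PipelineV1{};
    else if constexpr(Ver == PipelineVersion::v2)
        return PipelineV2{};
    else
        return PipelineV3{};
}

int main()
{
    auto pipe = pipeline_selector<PipelineVersion::v1>();
    std::cout << "prefetch stages: " << decltype(pipe)::prefetch_stages << "\n";
    return 0;
}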
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"

namespace ck {

// Naive pipeline with lowest resource request per WGP
// GlobalPrefetchStages: 1
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1

template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPacks>
struct BlockwiseGemmXdlops_pipeline_v1_b_scale
{
};

template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack
          // ,bool TransposeC //disable transposec right now...
          >
struct BlockwiseGemmXdlops_pipeline_v1_b_scale<BlockGemmPipelineScheduler::Intrawave,
                                               BlockSize, ADataType, BDataType, ComputeDataType,
                                               AccDataType, ATileDesc, BTileDesc, AMmaTileDesc,
                                               BMmaTileDesc, ABlockTransferSrcScalarPerVector,
                                               BBlockTransferSrcScalarPerVector, MPerBlock,
                                               NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat,
                                               NRepeat, KPack>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize, ADataType, BDataType, ComputeDataType,
                                        AccDataType, ATileDesc, BTileDesc, AMmaTileDesc,
                                        BMmaTileDesc, ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock,
                                        KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>
{
    using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize, ADataType, BDataType,
                                                   ComputeDataType, AccDataType, ATileDesc,
                                                   BTileDesc, AMmaTileDesc, BMmaTileDesc,
                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector, MPerBlock,
                                                   NPerBlock, KPerBlock, MPerXDL, NPerXDL,
                                                   MRepeat, NRepeat, KPack>;

    using Base::I0;
    using Base::KRepeat;
    using Base::xdlops_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;

    using Base::AMmaKStride;
    using Base::BMmaKStride;

    static constexpr index_t PrefetchStages  = 1;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 1;

    __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        ignore = num_loop;
        return TailNumber::Full;
    }

    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer,
              // typename AScaleGridBuffer,
              // typename AScaleGridDesc,
              // typename AScaleThreadDesc,
              // typename AScaleThreadTransfer,
              // typename AScaleThreadTransferStep,
              typename BScaleGridBuffer,
              typename BScaleGridDesc,
              typename BScaleThreadDesc,
              typename BScaleThreadTransfer,
              typename BScaleThreadTransferStep>
    __device__ void Run(
        // ABlockCopy
        const AGridDesc& a_grid_desc,
        const ABlockDesc& a_block_desc,
        ABlockTransfer& a_blockwise_copy,
        const AGridBuffer& a_grid_buf,
        ABlockBuffer& a_block_buf,
        const ABlockTransferStep& a_block_copy_step,
        // BBlockCopy
        const BGridDesc& b_grid_desc,
        const BBlockDesc& b_block_desc,
        BBlockTransfer& b_blockwise_copy,
        const BGridBuffer& b_grid_buf,
        BBlockBuffer& b_block_buf,
        const BBlockTransferStep& b_block_copy_step,
        // CThread
        CThreadBuffer& c_thread_buf,
        // // AScaleThreadCopy
        // const AScaleGridDesc& a_scale_grid_desc,
        // const AScaleThreadDesc& a_scale_thread_desc,
        // AScaleThreadTransfer& a_scale_thread_copy,
        // const AScaleGridBuffer& a_scale_grid_buf,
        // const AScaleThreadTransferStep& a_scale_thread_copy_step,
        // BScaleThreadCopy
        const BScaleGridDesc& b_scale_grid_desc,
        const BScaleThreadDesc& b_scale_thread_desc,
        BScaleThreadTransfer& b_scale_thread_copy,
        const BScaleGridBuffer& b_scale_grid_buf,
        const BScaleThreadTransferStep& b_scale_thread_copy_step,
        // num_loop
        index_t num_loop,
        index_t num_loop_per_scale) const
    {
        // assume kperblock = scaleblockk
        ignore = num_loop_per_scale;

        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());

        // auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
        //     a_scale_thread_desc.GetElementSpaceSize());
        // auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
        //     b_scale_thread_desc.GetElementSpaceSize());
        auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_scale_thread_desc.GetElementSpaceSize());

        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // a_scale_thread_copy.Run(a_scale_grid_desc,
        //                         a_scale_grid_buf,
        //                         a_scale_thread_desc,
        //                         make_tuple(I0, I0),
        //                         a_scale_thread_buf);

        static_for<0, NRepeat, 1>{}([&](auto n0) {
            b_scale_thread_copy.Run(b_scale_grid_desc,
                                    b_scale_grid_buf,
                                    b_scale_thread_desc,
                                    make_tuple(n0, I0),
                                    b_scale_thread_buf);

            b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                   b_scale_thread_copy_step.At(Number<0>{}));
        });
        b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                               b_scale_thread_copy_step.At(Number<1>{}));

        // a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
        // b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);

        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        // Initialize C
        c_thread_buf.Clear();

        auto c_thread_buf_per_scale = remove_cvref_t<decltype(c_thread_buf)>();

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                // -----------------------------------------------------------------------------
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                block_sync_lds();

                static_for<0, KRepeat, 1>{}([&](auto k) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                           make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                           a_block_buf,
                                           a_thread_desc_,
                                           make_tuple(m0, I0, k, I0),
                                           a_thread_buf);
                    });
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                           make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                           b_block_buf,
                                           b_thread_desc_,
                                           make_tuple(n0, I0, k, I0),
                                           b_thread_buf);
                    });
                });

                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        c_thread_buf_per_scale.Clear();
                        static_for<0, KRepeat, 1>{}([&](auto k0) {
                            vector_type<ComputeDataType, KPack> a_thread_vec;
                            vector_type<ComputeDataType, KPack> b_thread_vec;

                            static_for<0, KPack, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(m0, I0, k0, ik))>{}];
                                b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(n0, I0, k0, ik))>{}];
                            });

                            using mfma_input_type =
                                typename vector_type<ComputeDataType,
                                                     xdlops_gemm.K1PerXdlops>::type;

                            xdlops_gemm.template Run(
                                a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),
                                c_thread_buf_per_scale.GetVectorTypeReference(I0));
                        });
                        static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
                            c_thread_buf(Number<c_offset>{}) +=
                                c_thread_buf_per_scale[Number<t>{}] *
                                // type_convert<AccDataType>(a_scale_thread_buf[I0]) *
                                type_convert<AccDataType>(b_scale_thread_buf[n0]);
                        });
                    });
                });

                // a_scale_thread_copy.Run(a_scale_grid_desc,
                //                         a_scale_grid_buf,
                //                         a_scale_thread_desc,
                //                         make_tuple(I0, I0),
                //                         a_scale_thread_buf);

                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_scale_thread_copy.Run(b_scale_grid_desc,
                                            b_scale_grid_buf,
                                            b_scale_thread_desc,
                                            make_tuple(n0, I0),
                                            b_scale_thread_buf);
                    b_scale_thread_copy.MoveSrcSliceWindow(
                        b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
                });
                b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                       b_scale_thread_copy_step.At(Number<1>{}));

                // a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
                // b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);

                block_sync_lds();
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                i += 1;
            } while(i < (num_loop - 1));
        }

        // tail
        if constexpr(TailNum == TailNumber::Full)
        {
            block_sync_lds();

            static_for<0, KRepeat, 1>{}([&](auto k) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                       a_block_buf,
                                       a_thread_desc_,
                                       make_tuple(m0, I0, k, I0),
                                       a_thread_buf);
                });
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                       make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                       b_block_buf,
                                       b_thread_desc_,
                                       make_tuple(n0, I0, k, I0),
                                       b_thread_buf);
                });
            });

            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    c_thread_buf_per_scale.Clear();
                    static_for<0, KRepeat, 1>{}([&](auto k0) {
                        vector_type<ComputeDataType, KPack> a_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                        xdlops_gemm.template Run(
                            a_thread_vec.template AsType<mfma_input_type>(),
                            b_thread_vec.template AsType<mfma_input_type>(),
                            c_thread_buf_per_scale.GetVectorTypeReference(I0));
                    });
                    static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
                        c_thread_buf(Number<c_offset>{}) +=
                            c_thread_buf_per_scale[Number<t>{}] *
                            // type_convert<AccDataType>(a_scale_thread_buf[I0]) *
                            type_convert<AccDataType>(b_scale_thread_buf[n0]);
                    });
                });
            });
        }
    }

    protected:
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};

} // namespace ck
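The core of the new v1 pipeline is the two-level accumulation above: MFMA partial results for one K block are collected in c_thread_buf_per_scale and only then folded into c_thread_buf after multiplication by that N block's B scale. Below is a standalone scalar sketch of that blockwise dequantization math; the sizes and values are made up for illustration and are not CK code.

// Standalone scalar sketch of blockwise B-scale dequantization:
// accumulate int8 products per K block, then apply that block's scale.
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    const int K           = 8; // reduction length (assumption)
    const int ScaleBlockK = 4; // one B scale per 4 K elements (assumption)

    std::vector<float>  a(K, 0.5f);                      // A operand
    std::vector<int8_t> b_q(K, 64);                      // quantized B operand
    std::vector<float>  b_scale(K / ScaleBlockK, 0.01f); // per-block scales

    float c = 0.0f;
    for(int kb = 0; kb < K / ScaleBlockK; ++kb)
    {
        // analogous to c_thread_buf_per_scale: partial sum for this K block
        float c_per_scale = 0.0f;
        for(int k = kb * ScaleBlockK; k < (kb + 1) * ScaleBlockK; ++k)
            c_per_scale += a[k] * static_cast<float>(b_q[k]);

        // analogous to c_thread_buf(...) += c_thread_buf_per_scale[...] * b_scale
        c += c_per_scale * b_scale[kb];
    }

    std::cout << "c = " << c << "\n"; // 8 * 0.5 * 64 * 0.01 = 2.56
    return 0;
}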
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp (new file, 0 → 100644)

(This diff is collapsed.)
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp

@@ -330,7 +330,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
         // auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
         //     a_scale_thread_desc.GetElementSpaceSize());
         // auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
         //     b_scale_thread_desc.GetElementSpaceSize());
-        auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ck::half_t>(
+        auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
             b_scale_thread_desc.GetElementSpaceSize());

         // Global prefetch 1
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_scale.hpp

@@ -148,6 +148,8 @@ struct DeviceGemmMultiD_BScale_Xdl_CShuffle_V3
         BlkGemmPipelineVer,
         ComputeTypeA,
         ComputeTypeB,
+        PermuteA,
+        PermuteB,
         LDSTypeA,
         LDSTypeB>;
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_scale.hpp

@@ -111,6 +111,8 @@ template <typename ALayout,
           BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
           typename ComputeTypeA                       = CDataType,
           typename ComputeTypeB                       = ComputeTypeA,
+          bool PermuteA                               = false,
+          bool PermuteB                               = false,
           typename LDSTypeA                           = ADataType,
           typename LDSTypeB                           = BDataType>
 struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
...
@@ -158,7 +160,7 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

     static constexpr index_t APackedSize = []() {
         if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t>)
             return 2;
         else
...
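APackedSize (and the analogous BPackedSize) evaluates to 2 when the operand type is pk_i4_t, i.e. two 4-bit values per byte, and to 1 otherwise; the hunks below divide byte offsets and LDS sizes by these factors. Here is a standalone sketch of the same immediately-invoked-lambda idiom; pk_int4_stub is a hypothetical stand-in for ck::pk_i4_t.

// Standalone sketch of the PackedSize idiom: a constexpr lambda picks 2
// for a packed-int4 type and 1 otherwise. pk_int4_stub is hypothetical.
#include <cstdint>
#include <iostream>
#include <type_traits>

struct pk_int4_stub { uint8_t data; }; // two 4-bit values packed in one byte

template <typename T>
inline constexpr int packed_size = []() {
    if constexpr(std::is_same_v<std::remove_cv_t<T>, pk_int4_stub>)
        return 2;
    else
        return 1;
}();

int main()
{
    constexpr int k_read = 256; // B elements read per k-split block (made up)
    // Two int4 elements share one byte, so the addressable offset halves --
    // the reason this commit divides k-split offsets by APackedSize/BPackedSize.
    std::cout << "int8  offset: " << k_read / packed_size<int8_t> << "\n";       // 256
    std::cout << "pk_i4 offset: " << k_read / packed_size<pk_int4_stub> << "\n"; // 128
    return 0;
}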
@@ -340,6 +342,10 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
         using GemmSpecialization = tensor_operation::device::GemmSpecialization;

+        static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
+                        GemmSpec != GemmSpecialization::Default),
+                      "pk_i4_t does not support padding");
+
         if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                      GemmSpec == GemmSpecialization::MNKPadding)
         {
...
@@ -394,15 +400,39 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
         }
         else
         {
+            if constexpr(!PermuteB)
+            {
                 // not pad N or K
                 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                     b_grid_desc_nraw_kraw,
                     make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
                                make_pass_through_transform(N)),
                     make_tuple(Sequence<1>{}, Sequence<0>{}),
                     make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

                 return b_grid_desc_bk0_n_bk1;
+            }
+            else
+            {
+                // Weight Tile Permute
+                constexpr index_t BK01 = KPerBlock / BK1Value;
+                // const index_t BK00 = BK0 / BK01;
+                const index_t BK0_ = StrideB / BK1Value;
+                const index_t BK00 = BK0_ / BK01;
+
+                const auto b_grid_desc_bk00_n_bk01_bk1_permute =
+                    make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value));
+
+                const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor(
+                    b_grid_desc_bk00_n_bk01_bk1_permute,
+                    make_tuple(make_merge_transform(make_tuple(BK00, BK01)),
+                               make_pass_through_transform(make_tuple(N)),
+                               make_pass_through_transform(BK1Value)),
+                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+
+                return b_grid_desc_bk0_n_bk1_permute;
+            }
         }
     }
...
@@ -439,6 +469,13 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
         }
     }();

+    // // pad M and N
+    // return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
+    //                                    make_tuple(make_right_pad_transform(M, MPad - M),
+    //                                               make_right_pad_transform(N, NPad - N)),
+    //                                    make_tuple(Sequence<0>{}, Sequence<1>{}),
+    //                                    make_tuple(Sequence<0>{}, Sequence<1>{}));
+
     using GemmSpecialization = tensor_operation::device::GemmSpecialization;

     if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
...
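For PermuteB the new descriptor views B as a packed (BK00, N, BK01, BK1) tensor and merges (BK00, BK01) back into BK0. A standalone sketch of that index arithmetic follows; all sizes are made-up examples and none of the CK descriptor machinery is used.

// Standalone sketch of the PermuteB index math: a packed (BK00, N, BK01, BK1)
// layout, with (BK00, BK01) merged back into BK0.
#include <iostream>

int main()
{
    const int BK1       = 8;   // elements per K1 pack (assumption)
    const int KPerBlock = 64;  // assumption
    const int StrideB   = 256; // K extent as stored (assumption)
    const int N         = 4;   // assumption

    const int BK01 = KPerBlock / BK1; // 8
    const int BK0  = StrideB / BK1;   // 32
    const int BK00 = BK0 / BK01;      // 4

    // Offset of element (bk0, n, bk1) in the permuted, packed layout:
    auto offset = [&](int bk0, int n, int bk1) {
        const int bk00 = bk0 / BK01; // merge transform: bk0 = bk00 * BK01 + bk01
        const int bk01 = bk0 % BK01;
        return ((bk00 * N + n) * BK01 + bk01) * BK1 + bk1;
    };

    // Consecutive bk0 within one KPerBlock tile stay contiguous for a fixed n ...
    std::cout << offset(0, 0, 0) << " " << offset(1, 0, 0) << "\n";              // 0 8
    // ... while crossing a KPerBlock boundary jumps over all N columns.
    std::cout << offset(BK01 - 1, 0, 0) << " " << offset(BK01, 0, 0) << "\n";    // 56 256
    return 0;
}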
@@ -630,20 +667,28 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
     {
         if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
         {
-            a_k_split_offset = blockIdx.z * karg.KRead;
+            a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
         }
         else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
         {
-            a_k_split_offset = blockIdx.z * karg.KRead * karg.M;
+            a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
         }

         if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
         {
-            b_k_split_offset = blockIdx.z * karg.KRead * karg.N;
+            b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
         }
         else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
         {
-            b_k_split_offset = blockIdx.z * karg.KRead;
+            if constexpr(!PermuteB)
+            {
+                b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
+            }
+            else
+            {
+                const int k0_offset = karg.KRead * karg.N;
+                b_k_split_offset    = blockIdx.z * k0_offset / BPackedSize;
+            }
         }

         if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
...
@@ -673,9 +718,9 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
     // in some cases.
     else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
     {
-        constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1
-                                       ? 1
-                                       : 32 * 4 / KPerBlock / sizeof(LDSTypeA);
+        constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) / APackedSize < 1
+                                       ? 1
+                                       : 32 * 4 / KPerBlock / sizeof(LDSTypeA) / APackedSize;
         constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
             make_tuple(
                 AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
...
@@ -809,10 +854,10 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
     else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
     {
         // NLdsLayer * K0 as logical Bank
-        constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeB) < 1
-                                       ? 1
-                                       : 32 * 4 / KPerBlock / sizeof(LDSTypeB);
+        constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeB) / BPackedSize < 1
+                                       ? 1
+                                       : 32 * 4 / KPerBlock / sizeof(LDSTypeB) / BPackedSize;
         constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
             make_tuple(
                 BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
...
@@ -994,8 +1039,8 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
         constexpr auto c_block_size =
             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

-        return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) +
-                          b_block_space_size_aligned * sizeof(LDSTypeB)),
-                         c_block_size * sizeof(CShuffleDataType));
+        return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize +
+                          b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize),
+                         c_block_size * sizeof(CShuffleDataType));
     }
...
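The shared-memory requirement now divides the A and B tile footprints by their packed sizes before comparing against the CShuffle buffer. A quick standalone arithmetic check of that formula, with made-up tile sizes and element types (not the kernel's actual configuration):

// Standalone arithmetic check of the LDS byte budget shape used by
// GetSharedMemoryNumberOfByte after this change. All sizes are made up.
#include <algorithm>
#include <cstddef>
#include <iostream>

int main()
{
    const std::size_t a_block_space = 128 * 64;  // A tile elements (assumption)
    const std::size_t b_block_space = 128 * 64;  // B tile elements (assumption)
    const std::size_t c_block_size  = 128 * 128; // CShuffle tile elements (assumption)

    const std::size_t sizeof_lds_a   = 2; // e.g. fp16 in LDS (assumption)
    const std::size_t sizeof_lds_b   = 1; // e.g. int8 in LDS (assumption)
    const std::size_t sizeof_shuffle = 4; // e.g. fp32 CShuffle (assumption)
    const std::size_t a_packed_size  = 1; // would be 2 if A were pk_i4_t
    const std::size_t b_packed_size  = 1; // would be 2 if B were pk_i4_t

    const std::size_t lds_bytes =
        std::max(a_block_space * sizeof_lds_a / a_packed_size +
                     b_block_space * sizeof_lds_b / b_packed_size,
                 c_block_size * sizeof_shuffle);

    std::cout << "LDS bytes: " << lds_bytes << "\n"; // max(16384 + 8192, 65536) = 65536
    return 0;
}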
@@ -1009,7 +1054,8 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
         if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
                        GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                        GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
-                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
+                     !(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
         {
             if(!(karg.M % MPerBlock == 0))
             {
...
@@ -1026,7 +1072,8 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
         if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
                        GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                        GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
-                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
+                     (is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
         {
             if(!(karg.N % NPerBlock == 0))
             {
...
@@ -1228,10 +1275,6 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
         const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
             problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);

-        // const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
-        //     make_tuple(math::integer_divide_ceil(problem.M, ScaleBlockM),
-        //                math::integer_divide_ceil(problem.K, ScaleBlockK)),
-        //     make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
-
         const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
             make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
                        math::integer_divide_ceil(problem.K, ScaleBlockK)),
...
@@ -1247,10 +1290,6 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
             p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

-        // const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-        //     p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
-
         const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
...
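b_scale_grid_desc_bn_ak holds one scale per ScaleBlockN x ScaleBlockK tile of B, sized with ceiling division so ragged edges still get a scale. A quick standalone check of that sizing, with made-up problem and scale-block sizes:

// Standalone check of the B-scale grid sizing used above.
#include <iostream>

int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int N = 4096, K = 4096;                   // GEMM sizes (assumption)
    const int ScaleBlockN = 128, ScaleBlockK = 128; // quant granularity (assumption)

    const int scale_rows = integer_divide_ceil(N, ScaleBlockN); // 32
    const int scale_cols = integer_divide_ceil(K, ScaleBlockK); // 32

    std::cout << scale_rows << " x " << scale_cols << " = "
              << scale_rows * scale_cols << " B scales\n"; // 1024 scales
    return 0;
}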
@@ -1358,7 +1397,7 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             static_cast<LDSTypeB*>(p_shared) +
-                a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
+                a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB) / APackedSize,
             b_block_desc_bk0_n_bk1.GetElementSpaceSize());

         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
...
@@ -1377,37 +1416,14 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
         const index_t ScaleSliceSizeN = NXdlPerWave;
         const index_t ScaleSliceSizeK = 1;

-        // constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
-        //     make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));
-
         constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
             make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));

-        // auto a_scale_thread_copy =
-        //     ThreadwiseTensorSliceTransfer_v2<AScaleType,
-        //                                      AScaleType,
-        //                                      decltype(a_scale_grid_desc_am_ak),
-        //                                      decltype(a_scale_thread_desc),
-        //                                      Sequence<ScaleSliceSizeM, ScaleSliceSizeK>,
-        //                                      Sequence<0, 1>,
-        //                                      1,
-        //                                      1,
-        //                                      1,
-        //                                      false>(
-        //         a_scale_grid_desc_am_ak, make_multi_index(block_m_id * MPerBlock / ScaleBlockM,
-        //         0));
-
         constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
         auto b_thread_offset =
             get_thread_local_1d_id() % NPerXdl + (get_thread_local_1d_id() / 64) % NWaves * NPerXdl;

-        // __syncthreads();
-        // if(blockIdx.x==0)
-        // {
-        //     printf("ThreadIdx: %d, b_thread_offset: %d\n", get_thread_local_1d_id(), b_thread_offset);
-        // }
-
         auto b_scale_thread_copy =
             ThreadwiseTensorSliceTransfer_v2<BScaleType,
                                              BScaleType,
...
@@ -1422,7 +1438,6 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
             b_scale_grid_desc_bn_ak,
             make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset, 0));

-        // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
-
         constexpr auto b_scale_thread_slice_copy_step =
             make_tuple(make_multi_index(NWaves * NPerXdl, 0), make_multi_index(-NPerBlock, 1));
...
@@ -1443,12 +1458,6 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
             b_block_slice_copy_step,
             c_thread_buf,

-            // a_scale_grid_desc_am_ak,
-            // a_scale_thread_desc,
-            // a_scale_thread_copy,
-            // a_scale_grid_buf,
-            // a_scale_thread_slice_copy_step,
-
             b_scale_grid_desc_bn_ak,
             b_scale_thread_desc,
             b_scale_thread_copy,
...
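Each lane reads the B scale for the N column its MFMA output touches; b_thread_offset combines the lane's position inside an XDL tile with its wave's position along N inside the block. A standalone sketch of that mapping follows; NPerXdl and NWaves values are made up, while 64 is the wavefront size as used in the kernel.

// Standalone sketch of the b_thread_offset computation used when
// constructing b_scale_thread_copy.
#include <iostream>

int main()
{
    const int NPerXdl = 32; // N covered by one XDL instruction (assumption)
    const int NWaves  = 2;  // waves tiled along N inside the block (assumption)

    for(int tid : {0, 31, 64, 95, 128})
    {
        const int b_thread_offset = tid % NPerXdl + (tid / 64) % NWaves * NPerXdl;
        std::cout << "thread " << tid << " -> N offset " << b_thread_offset << "\n";
    }
    // thread 0 -> 0, thread 31 -> 31, thread 64 -> 32 (second wave),
    // thread 95 -> 63, thread 128 -> 0 (wraps back to the first N slice)
    return 0;
}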