Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
edc494df
Commit
edc494df
authored
Aug 04, 2022
by
Anthony Chang
Browse files
clang-format
parent
00331ee4
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
127 additions
and
128 deletions
+127
-128
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp
+45
-39
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
...e/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
+6
-8
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp
...tion/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp
+46
-62
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
...n/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
+10
-6
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
...operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
+9
-6
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+5
-3
library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp
...reference_tensor_operation/cpu/reference_batched_gemm.hpp
+2
-1
library/include/ck/library/utility/host_tensor_generator.hpp
library/include/ck/library/utility/host_tensor_generator.hpp
+4
-3
No files found.
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp
View file @
edc494df
...
...
@@ -48,10 +48,10 @@ using B0Layout = Col;
using
B1Layout
=
Row
;
using
CLayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
AElementOp
=
PassThrough
;
using
B0ElementOp
=
PassThrough
;
using
B1ElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
...
...
@@ -113,14 +113,19 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_X
8
>
;
// CShuffleBlockTransferScalarPerVector_NPerBlock
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
B0DataType
,
ADataType
,
AccDataType
,
AElementOp
,
B0ElementOp
,
CElementOp
>
;
using
ReferenceGemm1Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
B1DataType
,
CDataType
,
AccDataType
,
AElementOp
,
B1ElementOp
,
CElementOp
>
;
B0DataType
,
ADataType
,
AccDataType
,
AElementOp
,
B0ElementOp
,
CElementOp
>
;
using
ReferenceGemm1Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
B1DataType
,
CDataType
,
AccDataType
,
AElementOp
,
B1ElementOp
,
CElementOp
>
;
int
main
(
int
argc
,
char
*
argv
[])
{
...
...
@@ -179,15 +184,15 @@ int main(int argc, char* argv[])
BatchCount
=
std
::
stoi
(
argv
[
8
]);
StrideA
=
std
::
stoi
(
argv
[
9
]);
StrideA
=
std
::
stoi
(
argv
[
9
]);
StrideB0
=
std
::
stoi
(
argv
[
10
]);
StrideB1
=
std
::
stoi
(
argv
[
11
]);
StrideC
=
std
::
stoi
(
argv
[
12
]);
StrideC
=
std
::
stoi
(
argv
[
12
]);
BatchStrideA
=
std
::
stoi
(
argv
[
13
]);
BatchStrideA
=
std
::
stoi
(
argv
[
13
]);
BatchStrideB0
=
std
::
stoi
(
argv
[
14
]);
BatchStrideB1
=
std
::
stoi
(
argv
[
15
]);
BatchStrideC
=
std
::
stoi
(
argv
[
16
]);
BatchStrideC
=
std
::
stoi
(
argv
[
16
]);
}
else
{
...
...
@@ -282,35 +287,36 @@ int main(int argc, char* argv[])
b0_g_k_n_device_buf
.
ToDevice
(
b0_g_k_n
.
mData
.
data
());
b1_g_n_o_device_buf
.
ToDevice
(
b1_g_n_o
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
a_element_op
=
AElementOp
{};
auto
b0_element_op
=
B0ElementOp
{};
auto
b1_element_op
=
B1ElementOp
{};
auto
c_element_op
=
CElementOp
{};
auto
c_element_op
=
CElementOp
{};
// do GEMM
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_g_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
B0DataType
*>
(
b0_g_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
B1DataType
*>
(
b1_g_n_o_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_g_m_o_device_buf
.
GetDeviceBuffer
()),
M
,
N
,
K
,
O
,
BatchCount
,
StrideA
,
StrideB0
,
StrideB1
,
StrideC
,
BatchStrideA
,
BatchStrideB0
,
BatchStrideB1
,
BatchStrideC
,
a_element_op
,
b0_element_op
,
b1_element_op
,
c_element_op
);
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_g_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
B0DataType
*>
(
b0_g_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
B1DataType
*>
(
b1_g_n_o_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_g_m_o_device_buf
.
GetDeviceBuffer
()),
M
,
N
,
K
,
O
,
BatchCount
,
StrideA
,
StrideB0
,
StrideB1
,
StrideC
,
BatchStrideA
,
BatchStrideB0
,
BatchStrideB1
,
BatchStrideC
,
a_element_op
,
b0_element_op
,
b1_element_op
,
c_element_op
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
View file @
edc494df
...
...
@@ -35,8 +35,8 @@ MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
return
transform_tensor_descriptor
(
TileDesc_K0_MN_K1
{},
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
K0
>
{},
Number
<
K1
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
MNXdlPerWave
>
{},
Number
<
MNWaves
>
{},
Number
<
MNPerXdl
>
{}))),
make_unmerge_transform
(
make_tuple
(
Number
<
MNXdlPerWave
>
{},
Number
<
MNWaves
>
{},
Number
<
MNPerXdl
>
{}))),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}));
}
...
...
@@ -694,7 +694,7 @@ struct BlockwiseGemmXdlops_v2
template
<
index_t
m0
,
index_t
n0
,
index_t
xdlops_i
,
index_t
blk_i
>
__device__
static
auto
CalculateCThreadOriginDataIndex
(
Number
<
m0
>
,
Number
<
n0
>
,
Number
<
xdlops_i
>
,
Number
<
blk_i
>
)
CalculateCThreadOriginDataIndex
(
Number
<
m0
>
,
Number
<
n0
>
,
Number
<
xdlops_i
>
,
Number
<
blk_i
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
...
...
@@ -723,9 +723,8 @@ struct BlockwiseGemmXdlops_v2
using
Tuple4
=
decltype
(
CalculateAThreadOriginDataIndex
());
__host__
__device__
BlockwiseGemmXdlops_v2
(
Tuple4
a_origin
=
CalculateAThreadOriginDataIndex
(),
Tuple4
b_origin
=
CalculateBThreadOriginDataIndex
())
__host__
__device__
BlockwiseGemmXdlops_v2
(
Tuple4
a_origin
=
CalculateAThreadOriginDataIndex
(),
Tuple4
b_origin
=
CalculateBThreadOriginDataIndex
())
:
a_thread_copy_
(
a_origin
),
b_thread_copy_
(
b_origin
)
{
static_assert
(
AMmaTileDesc
::
IsKnownAtCompileTime
()
&&
BMmaTileDesc
::
IsKnownAtCompileTime
(),
...
...
@@ -738,8 +737,7 @@ struct BlockwiseGemmXdlops_v2
"wrong!"
);
}
__host__
__device__
BlockwiseGemmXdlops_v2
(
const
BlockwiseGemmXdlops_v2
&
other
)
__host__
__device__
BlockwiseGemmXdlops_v2
(
const
BlockwiseGemmXdlops_v2
&
other
)
:
a_thread_copy_
(
other
.
a_origin
),
b_thread_copy_
(
other
.
b_origin
)
{
}
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp
View file @
edc494df
...
...
@@ -38,22 +38,23 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_gemm_gemm_xdl_cshuffle_v1
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b1_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
B1ElementwiseOperation
b1_element_op
,
const
CElementwiseOperation
c_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2CTileMap
block_2_ctile_map
,
const
index_t
batch_count
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
)
kernel_gemm_gemm_xdl_cshuffle_v1
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b1_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
B1ElementwiseOperation
b1_element_op
,
const
CElementwiseOperation
c_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2CTileMap
block_2_ctile_map
,
const
index_t
batch_count
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
...
@@ -162,17 +163,17 @@ template <typename ALayout,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
struct
DeviceBatchedGemmGemm_Xdl_CShuffle
:
public
DeviceBatchedGemmGemm
<
ALayout
,
BLayout
,
B1Layout
,
CLayout
,
ADataType
,
BDataType
,
B1DataType
,
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
>
BLayout
,
B1Layout
,
CLayout
,
ADataType
,
BDataType
,
B1DataType
,
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
>
{
using
DeviceOp
=
DeviceBatchedGemmGemm_Xdl_CShuffle
;
...
...
@@ -405,12 +406,12 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
{
const
auto
B1K0
=
KRaw
/
B1K1
;
const
auto
b1_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b1_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B1K0
,
B1K1
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
const
auto
b1_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b1_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B1K0
,
B1K1
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b1_grid_desc_bk0_n_bk1
;
}
...
...
@@ -426,16 +427,15 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b1_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b1_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B1K0
,
B1K1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
const
auto
b1_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b1_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B1K0
,
B1K1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b1_grid_desc_bk0_n_bk1
;
}
}
static
auto
MakeCGridDescriptor_M_N
(
index_t
MRaw
,
index_t
NRaw
,
index_t
StrideC
)
...
...
@@ -537,9 +537,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
};
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
(
1
,
1
,
1
));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
(
1
,
1
,
1
));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
(
1
,
1
,
1
));
using
B1GridDesc_BK0_N_BK1
=
decltype
(
MakeB1GridDescriptor_BK0_N_BK1
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedGemmGemm_Xdl_CShuffle
<
...
...
@@ -809,26 +809,10 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
)
{
return
Argument
{
p_a
,
p_b
,
p_b1
,
p_c
,
MRaw
,
NRaw
,
KRaw
,
Gemm1NRaw
,
Batch
,
StrideA
,
StrideB
,
StrideB1
,
StrideC
,
BatchStrideA
,
BatchStrideB
,
BatchStrideB1
,
BatchStrideC
,
a_element_op
,
b_element_op
,
b1_element_op
,
return
Argument
{
p_a
,
p_b
,
p_b1
,
p_c
,
MRaw
,
NRaw
,
KRaw
,
Gemm1NRaw
,
Batch
,
StrideA
,
StrideB
,
StrideB1
,
StrideC
,
BatchStrideA
,
BatchStrideB
,
BatchStrideB1
,
BatchStrideC
,
a_element_op
,
b_element_op
,
b1_element_op
,
c_element_op
};
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
View file @
edc494df
...
...
@@ -181,8 +181,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
constexpr
auto
b_block_desc_bk0_n_bk1
=
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
constexpr
auto
b_block_desc_bk0_n_bk1
=
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
constexpr
auto
b1_block_desc_bk0_n_bk1
=
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
// lds max alignment
...
...
@@ -207,7 +207,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
constexpr
auto
c_block_size
=
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
();
return
math
::
max
((
a_block_space_size_aligned
+
b_block_space_size_aligned
)
*
sizeof
(
FloatAB
),
return
math
::
max
((
a_block_space_size_aligned
+
b_block_space_size_aligned
)
*
sizeof
(
FloatAB
),
c_block_size
*
sizeof
(
FloatCShuffle
));
}
...
...
@@ -234,7 +235,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
return
false
;
}
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
Gemm1N
%
Gemm1NPerBlock
==
0
))
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
Gemm1N
%
Gemm1NPerBlock
==
0
))
{
return
false
;
}
...
...
@@ -472,8 +474,10 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
AK1
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
BK1
,
0
,
0
);
const
auto
a_block_reset_copy_step
=
make_multi_index
(
-
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
),
0
,
0
);
const
auto
b_block_reset_copy_step
=
make_multi_index
(
-
b_grid_desc_bk0_n_bk1
.
GetLength
(
I0
),
NPerBlock
,
0
);
const
auto
a_block_reset_copy_step
=
make_multi_index
(
-
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
),
0
,
0
);
const
auto
b_block_reset_copy_step
=
make_multi_index
(
-
b_grid_desc_bk0_n_bk1
.
GetLength
(
I0
),
NPerBlock
,
0
);
// gridwise GEMM pipeline
// Only supports LoopScheduler::Default
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
View file @
edc494df
...
...
@@ -1154,11 +1154,11 @@ struct ThreadwiseTensorSliceTransfer_v4
{
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
src_offset
=
src_desc
.
CalculateOffset
(
src_ref_to_origin_disp_idx
+
data_to_origin_disp_idx
+
i
*
src_scalar_step_in_vector
);
src_ref_to_origin_disp_idx
+
data_to_origin_disp_idx
+
i
*
src_scalar_step_in_vector
);
// apply type convert
src_tmp_vector
.
template
AsType
<
SrcData
>()(
i
)
=
src_buf
[
Number
<
src_offset
>
{}];
src_tmp_vector
.
template
AsType
<
SrcData
>()(
i
)
=
src_buf
[
Number
<
src_offset
>
{}];
});
}
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
...
...
@@ -1206,7 +1206,8 @@ template <typename SrcData,
typename
DimAccessOrder
,
index_t
DstVectorDim
,
index_t
DstScalarPerVector
,
typename
enable_if
<
SrcDesc
::
IsKnownAtCompileTime
()
&&
DstDesc
::
IsKnownAtCompileTime
(),
bool
>
::
type
=
false
>
typename
enable_if
<
SrcDesc
::
IsKnownAtCompileTime
()
&&
DstDesc
::
IsKnownAtCompileTime
(),
bool
>
::
type
=
false
>
struct
ThreadwiseTensorSliceTransfer_StaticToStatic
{
static
constexpr
index_t
nDim
=
SliceLengths
::
Size
();
...
...
@@ -1222,7 +1223,10 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
"wrong! Not divisible"
);
}
template
<
typename
SrcSliceOriginIdx
,
typename
DstSliceOriginIdx
,
typename
SrcBuffer
,
typename
DstBuffer
>
template
<
typename
SrcSliceOriginIdx
,
typename
DstSliceOriginIdx
,
typename
SrcBuffer
,
typename
DstBuffer
>
__device__
void
Run
(
const
SrcDesc
&
,
const
SrcSliceOriginIdx
&
,
const
SrcBuffer
&
src_buf
,
...
...
@@ -1277,7 +1281,6 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
});
});
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
edc494df
...
...
@@ -739,13 +739,15 @@ struct XdlopsGemm
"base base_type must be double, float, half, bfloat16, and int8_t!"
);
static_for
<
0
,
KPack
/
mfma_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
if
constexpr
(
!
TransposeC
)
if
constexpr
(
!
TransposeC
)
{
mfma_instr
.
template
run
<
MPerXdlops
,
NPerXdlops
>(
p_a_wave
[
k
],
p_b_wave
[
k
],
p_c_thread
);
mfma_instr
.
template
run
<
MPerXdlops
,
NPerXdlops
>(
p_a_wave
[
k
],
p_b_wave
[
k
],
p_c_thread
);
}
else
{
mfma_instr
.
template
run
<
MPerXdlops
,
NPerXdlops
>(
p_b_wave
[
k
],
p_a_wave
[
k
],
p_c_thread
);
mfma_instr
.
template
run
<
MPerXdlops
,
NPerXdlops
>(
p_b_wave
[
k
],
p_a_wave
[
k
],
p_c_thread
);
}
});
}
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp
View file @
edc494df
...
...
@@ -69,7 +69,8 @@ struct ReferenceBatchedGemm : public device::BaseOperator
arg
.
a_element_op_
(
v_a
,
arg
.
a_g_m_k_
(
g
,
m
,
k
));
arg
.
b_element_op_
(
v_b
,
arg
.
b_g_k_n_
(
g
,
k
,
n
));
v_acc
+=
ck
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck
::
type_convert
<
AccDataType
>
(
v_b
);
v_acc
+=
ck
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck
::
type_convert
<
AccDataType
>
(
v_b
);
}
AccDataType
v_c
;
...
...
library/include/ck/library/utility/host_tensor_generator.hpp
View file @
edc494df
...
...
@@ -161,9 +161,10 @@ struct GeneratorTensor_Diagonal
T
operator
()(
Ts
...
Xs
)
const
{
std
::
array
<
ck
::
index_t
,
sizeof
...(
Ts
)
>
dims
=
{{
static_cast
<
ck
::
index_t
>
(
Xs
)...}};
size_t
start_dim
=
dims
.
size
()
-
NumEffectiveDim
;
bool
pred
=
true
;
for
(
size_t
i
=
start_dim
+
1
;
i
<
dims
.
size
();
i
++
)
{
size_t
start_dim
=
dims
.
size
()
-
NumEffectiveDim
;
bool
pred
=
true
;
for
(
size_t
i
=
start_dim
+
1
;
i
<
dims
.
size
();
i
++
)
{
pred
&=
(
dims
[
start_dim
]
==
dims
[
i
]);
}
return
pred
?
value
:
T
{
0
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment