gaoqiong / composable_kernel · Commits · a4b08b57

Commit a4b08b57, authored Oct 05, 2023 by Adam Osewski
Parent: 7316bd15

Generalize kernel to grouped_gemm and add more test cases.
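The change below generalizes the single-GEMM test kernel into a grouped one: instead of a single (M, N, K) problem, the kernel now receives an array of GemmArgDesc entries (one per GEMM group), and each workgroup locates the group that owns its current work tile by walking the cumulative per-group tile counts. A minimal host-side sketch of that lookup, simplified from the device loop in the kernel below (GroupDesc and find_group are illustrative names, not part of this commit):

// Illustrative sketch only: mirrors the "Find corresponding GEMM group" loop in the
// grouped kernel below, assuming tile ids are assigned contiguously per group.
#include <cstddef>
#include <vector>

struct GroupDesc
{
    int tile_count; // number of work tiles owned by this group
};

// Walk groups in order, accumulating [start, start + tile_count) ranges until the
// range containing tile_id is found. Precondition: tile_id < sum of all tile counts.
std::size_t find_group(const std::vector<GroupDesc>& descs, int tile_id)
{
    int start         = 0;
    std::size_t group = 0;
    while(!(tile_id >= start && tile_id < start + descs[group].tile_count))
    {
        start += descs[group].tile_count;
        ++group;
    }
    return group;
}

// Example: with per-group tile counts {4, 8}, tile id 6 falls in [4, 12) -> group 1.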
Showing 1 changed file with 330 additions and 119 deletions.

test/work_scheduling/test_strided_reduction_tile_loop.cpp (+330 -119), view file @ a4b08b57
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <memory>
#include <vector>

#include <gtest/gtest.h>

...
@@ -19,103 +20,167 @@
...

#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace {

using namespace ck;

struct GemmArgDesc
{
    GemmArgDesc(index_t M_,
                index_t N_,
                index_t K_,
                const float* p_A_,
                const float* p_B_,
                float* p_C_,
                index_t tile_count_)
        : M{M_}, N{N_}, K{K_}, p_A{p_A_}, p_B{p_B_}, p_C{p_C_}, tile_count{tile_count_}
    {
    }

    index_t M;
    index_t N;
    index_t K;
    const float* p_A;
    const float* p_B;
    float* p_C;
    index_t tile_count;
};
template <index_t MPerBlock, index_t NPerBlock, index_t KPerBlock>
__global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p_gemm_descs,
                                                            float* p_workspace,
                                                            uint32_t* p_flags,
                                                            index_t tile_count,
                                                            index_t k_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
    StridedReductionTileLoop work_scheduler{tile_count, p_flags};

    // early exit if no work.
    if(work_scheduler.tile_id_ >= tile_count)
        return;

    index_t group_id           = 0;
    index_t offset             = 0;
    index_t grid_size_grp      = p_gemm_descs[group_id].tile_count;
    index_t gemm_tile_id_start = 0;
    index_t gemm_tile_id_end   = grid_size_grp;

    do
    {
        // Find corresponding GEMM group for out tile
        while(!(work_scheduler.tile_id_ >= gemm_tile_id_start &&
                work_scheduler.tile_id_ < gemm_tile_id_end))
        {
            // Step to next GEMM group and update data tile bounds.
            offset += grid_size_grp;
            group_id++;
            grid_size_grp      = p_gemm_descs[group_id].tile_count;
            gemm_tile_id_start = offset;
            gemm_tile_id_end   = offset + grid_size_grp;
        }

        const index_t M = p_gemm_descs[group_id].M;
        const index_t N = p_gemm_descs[group_id].N;
        const index_t K = p_gemm_descs[group_id].K;

        const auto p_A = p_gemm_descs[group_id].p_A;
        const auto p_B = p_gemm_descs[group_id].p_B;
        const auto p_C = p_gemm_descs[group_id].p_C;

        const auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N));
        BlockToCTileMap_LinearKSplit<MPerBlock, NPerBlock> b2c_tile_map(c_grid_desc_m_n, k_batch);

        float partial_result = 0.f;

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        // Assume MK-KN-MN data layout
        const index_t stride_a = K;
        const index_t stride_b = N;
        const index_t stride_c = N;

        // K is the contiguous dim in memory, as well as fastest changing dim in B2C mapping.
        const auto block_work_idx =
            b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_ - offset);
        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

        do
        {
            const index_t k_batch_id = __builtin_amdgcn_readfirstlane(b2c_tile_map.GetTileKIdx());

            const index_t A_m_tile_offset     = block_m_id * MPerBlock;
            const index_t A_k_tile_offset     = k_batch_id * KPerBlock;
            const index_t A_thread_tile_m_idx = get_thread_local_1d_id() / NPerBlock;

            const index_t B_n_tile_offset     = block_n_id * NPerBlock;
            const index_t B_k_tile_offset     = k_batch_id * KPerBlock;
            const index_t B_thread_tile_n_idx = get_thread_local_1d_id() % NPerBlock;

            for(index_t k = 0; k < KPerBlock; ++k)
            {
                partial_result +=
                    p_A[(A_m_tile_offset + A_thread_tile_m_idx) * stride_a + A_k_tile_offset + k] *
                    p_B[(B_k_tile_offset + k) * stride_b + B_n_tile_offset + B_thread_tile_n_idx];
            }
        } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());

        // if next [M,N] tile
        if(!b2c_tile_map.IsFirstKSplitBlock())
        {
            // Assume we have MPerBlock x NPerBlock tile per each workgroup in contiguous memory.
            p_workspace[get_block_1d_id() * MPerBlock * NPerBlock + get_thread_local_1d_id()] =
                partial_result;
        }

        const index_t output_tile_idx        = b2c_tile_map.GetOutputTileIdx();
        const index_t output_tile_idx_offset = offset / k_batch;

        work_scheduler.FlagFinished(k_batch, output_tile_idx, output_tile_idx_offset);

        // The workgroup which processed first K tile accumulates results and stores to GMEM
        if(b2c_tile_map.IsFirstKSplitBlock())
        {
            // Wait untill all other blocks for this [M,N] tile store their results.
            work_scheduler.WaitForNeighbours(k_batch, output_tile_idx, output_tile_idx_offset);

            // Accumulate partial results. We can have different # of workgroups to reduce, thus we
            // read actual flag value.
            const index_t flag_v = __builtin_amdgcn_readfirstlane(
                work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));

            for(index_t i = 1; i < flag_v; ++i)
            {
                partial_result += p_workspace[(get_block_1d_id()) * MPerBlock * NPerBlock +
                                              i * MPerBlock * NPerBlock + get_thread_local_1d_id()];
            }

            // Signal waiting blocks that they can start use their workspace.
            work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);

            // write result
            const index_t C_m_tile_offset     = block_m_id * MPerBlock;
            const index_t C_thread_tile_m_idx = get_thread_local_1d_id() / NPerBlock;
            const index_t C_n_tile_offset     = block_n_id * NPerBlock;
            const index_t C_thread_tile_n_idx = get_thread_local_1d_id() % NPerBlock;

            p_C[(C_m_tile_offset + C_thread_tile_m_idx) * stride_c + C_n_tile_offset +
                C_thread_tile_n_idx] = partial_result;
        }
        else
        {
            work_scheduler.WaitForReduction(k_batch, output_tile_idx, output_tile_idx_offset);
        }
    } while(work_scheduler.HasTile());
#else
    ignore = p_gemm_descs;
    ignore = p_workspace;
    ignore = p_flags;
    ignore = tile_count;
    ...
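One property of the naive mapping above worth noting: each thread covers exactly one element of the MPerBlock x NPerBlock output tile (row = thread id / NPerBlock, column = thread id % NPerBlock), which is presumably why the tests below pick BlockSize = 256 for 8x32 tiles. A small standalone check of that assumption (illustrative, not part of the commit):

// Reproduces the thread -> (row, col) mapping used by the kernel above and checks
// that a 256-thread block exactly covers an 8 x 32 output tile (one element per thread).
#include <cassert>

int main()
{
    const int MPerBlock = 8;   // values used by the tests below
    const int NPerBlock = 32;
    const int BlockSize = 256; // assumed to equal MPerBlock * NPerBlock

    assert(BlockSize == MPerBlock * NPerBlock);

    for(int tid = 0; tid < BlockSize; ++tid)
    {
        const int row = tid / NPerBlock; // A_thread_tile_m_idx / C_thread_tile_m_idx
        const int col = tid % NPerBlock; // B_thread_tile_n_idx / C_thread_tile_n_idx
        assert(row >= 0 && row < MPerBlock);
        assert(col >= 0 && col < NPerBlock);
    }
    return 0;
}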
@@ -126,7 +191,7 @@ __global__ void gemm_naive_strided_tile_loop_reduce(index_t M,
...
} // namespace

template <index_t BlockSize, index_t MPerBlock, index_t NPerBlock, index_t KPerBlock>
struct GroupedGemmStridedTileLoopReduce
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using AElementOp  = PassThrough;
...
@@ -139,7 +204,7 @@ struct GemmStridedTileLoopReduce
...
    using AccDataType = float;

    constexpr static auto DeviceGemmKernel =
        grouped_gemm_naive_strided_tile_loop_reduce<MPerBlock, NPerBlock, KPerBlock>;

    using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                            BDataType,
...
@@ -149,32 +214,75 @@ struct GemmStridedTileLoopReduce
...
                                                                            BElementOp,
                                                                            CElementOp>;

    GroupedGemmStridedTileLoopReduce() = default;

    bool Run(std::vector<index_t> Ms,
             std::vector<index_t> Ns,
             std::vector<index_t> Ks,
             index_t k_batch,
             index_t grid_size)
    {
        EXPECT_TRUE(Ms.size() == Ns.size() && Ms.size() == Ks.size());
        std::size_t group_count = Ms.size();

        std::vector<Tensor<float>> a_m_k;
        std::vector<Tensor<float>> b_k_n;
        std::vector<Tensor<float>> c_m_n_host;
        std::vector<Tensor<float>> c_m_n_device;

        using DeviceMemPtr = std::unique_ptr<DeviceMem>;
        std::vector<DeviceMemPtr> a_m_k_device_buf;
        std::vector<DeviceMemPtr> b_k_n_device_buf;
        std::vector<DeviceMemPtr> c_m_n_device_buf;

        std::vector<GemmArgDesc> gemm_descs;
        gemm_descs.reserve(group_count);
        index_t tile_count = 0;

        for(std::size_t i = 0; i < group_count; ++i)
        {
            a_m_k.push_back(Tensor<float>(HostTensorDescriptor({Ms[i], Ks[i]}, {Ks[i], 1})));
            b_k_n.push_back(Tensor<float>(HostTensorDescriptor({Ks[i], Ns[i]}, {Ns[i], 1})));
            c_m_n_host.push_back(Tensor<float>(HostTensorDescriptor({Ms[i], Ns[i]}, {Ns[i], 1})));
            c_m_n_device.push_back(Tensor<float>(HostTensorDescriptor({Ms[i], Ns[i]}, {Ns[i], 1})));

            ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k[i]);
            ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n[i]);
            c_m_n_host[i].SetZero();
            c_m_n_device[i].SetZero();

            a_m_k_device_buf.emplace_back(
                std::make_unique<DeviceMem>(sizeof(float) * a_m_k[i].mDesc.GetElementSpaceSize()));
            b_k_n_device_buf.emplace_back(
                std::make_unique<DeviceMem>(sizeof(float) * b_k_n[i].mDesc.GetElementSpaceSize()));
            c_m_n_device_buf.emplace_back(std::make_unique<DeviceMem>(
                sizeof(float) * c_m_n_device[i].mDesc.GetElementSpaceSize()));

            a_m_k_device_buf[i]->ToDevice(a_m_k[i].mData.data());
            b_k_n_device_buf[i]->ToDevice(b_k_n[i].mData.data());
            c_m_n_device_buf[i]->SetZero();

            BlockToCTileMap_LinearKSplit<MPerBlock, NPerBlock> b2c_tile_map(Ms[i], Ns[i], k_batch);
            index_t grp_tile_count = b2c_tile_map.CalculateGridSize(Ms[i], Ns[i]);
            tile_count += grp_tile_count;

            gemm_descs.emplace_back(Ms[i],
                                    Ns[i],
                                    Ks[i],
                                    reinterpret_cast<float*>(a_m_k_device_buf[i]->GetDeviceBuffer()),
                                    reinterpret_cast<float*>(b_k_n_device_buf[i]->GetDeviceBuffer()),
                                    reinterpret_cast<float*>(c_m_n_device_buf[i]->GetDeviceBuffer()),
                                    grp_tile_count);
        }

        DeviceMem gemm_descs_device_buf{gemm_descs.size() * sizeof(GemmArgDesc)};
        gemm_descs_device_buf.ToDevice(gemm_descs.data());

        DeviceMem gemm_workspace, gemm_flags;

        const index_t tiles_per_block = (tile_count + grid_size - 1) / grid_size;
        // This is the number of MN-output tiles which we cover with workgroups.
        // We launch k_batch / tiles_per_block workgroups for each output tile.
...
@@ -186,21 +294,17 @@ struct GemmStridedTileLoopReduce
...
        gemm_workspace.SetZero();
        gemm_flags.SetZero();

        launch_and_time_kernel(
            StreamConfig{nullptr, false},
            DeviceGemmKernel,
            dim3(grid_size),
            dim3(BlockSize),
            0,
            reinterpret_cast<const GemmArgDesc*>(gemm_descs_device_buf.GetDeviceBuffer()),
            reinterpret_cast<float*>(gemm_workspace.GetDeviceBuffer()),
            reinterpret_cast<uint32_t*>(gemm_flags.GetDeviceBuffer()),
            tile_count,
            k_batch);

        auto a_element_op = AElementOp{};
        auto b_element_op = BElementOp{};
...
@@ -209,48 +313,155 @@ struct GemmStridedTileLoopReduce
...
        auto ref_gemm    = ReferenceGemmInstance{};
        auto ref_invoker = ref_gemm.MakeInvoker();

        bool pass = true;

        for(std::size_t i = 0; i < group_count; ++i)
        {
            auto ref_argument = ref_gemm.MakeArgument(
                a_m_k[i], b_k_n[i], c_m_n_host[i], a_element_op, b_element_op, c_element_op);
            ref_invoker.Run(ref_argument);

            c_m_n_device_buf[i]->FromDevice(c_m_n_device[i].mData.data());
            pass = pass && ck::utils::check_err(c_m_n_device[i], c_m_n_host[i]);
        }
        return pass;
    }
};
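To relate the new grid_size argument to the work split computed in Run() above, take the GroupedGemm_SingleDataTile case below. Assuming CalculateGridSize(M, N) reports the number of data tiles, i.e. output tiles times k_batch (an assumption suggested by the test settings, not shown in this commit), a single 8x32 output tile with k_batch = 4 gives tile_count = 4, so grid_size = 4 yields tiles_per_block = 1:

// Illustrative check of the ceiling-division work split from Run() above; the
// CalculateGridSize semantics (output tiles * k_batch) is an assumption.
#include <cassert>

int main()
{
    const int output_tiles = 1; // one MPerBlock x NPerBlock (8 x 32) output tile
    const int k_batch      = 4;
    const int tile_count   = output_tiles * k_batch; // assumed CalculateGridSize result
    const int grid_size    = 4;                      // as passed to Run() in the test below

    const int tiles_per_block = (tile_count + grid_size - 1) / grid_size;
    assert(tiles_per_block == 1); // each workgroup handles one K-slice of the tile
    return 0;
}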
TEST(TestStridedReductionTileLoop, GroupedGemm_SingleDataTile)
{
    constexpr index_t MPerBlock = 8;
    constexpr index_t NPerBlock = 32;
    constexpr index_t KPerBlock = 32;
    constexpr index_t BlockSize = 256;

    const index_t kbatch    = 4;
    const index_t grid_size = 4;

    std::vector<index_t> Ms(1, MPerBlock);
    std::vector<index_t> Ns(1, NPerBlock);
    std::vector<index_t> Ks(1, KPerBlock * kbatch);

    EXPECT_TRUE((GroupedGemmStridedTileLoopReduce<BlockSize, MPerBlock, NPerBlock, KPerBlock>{}.Run(
        Ms, Ns, Ks, kbatch, grid_size)));
}

TEST(TestStridedReductionTileLoop, GroupedGemm_SingleOutputMultipleDataTiles)
{
    constexpr index_t MPerBlock = 8;
    constexpr index_t NPerBlock = 32;
    constexpr index_t KPerBlock = 32;
    constexpr index_t BlockSize = 256;

    const index_t kbatch    = 16;
    const index_t grid_size = 4;

    std::vector<index_t> Ms(1, MPerBlock);
    std::vector<index_t> Ns(1, NPerBlock);
    std::vector<index_t> Ks(1, KPerBlock * kbatch);

    EXPECT_TRUE((GroupedGemmStridedTileLoopReduce<BlockSize, MPerBlock, NPerBlock, KPerBlock>{}.Run(
        Ms, Ns, Ks, kbatch, grid_size)));
}

TEST(TestStridedReductionTileLoop, GroupedGemm_MultipleDataTiles)
{
    constexpr index_t MPerBlock = 8;
    constexpr index_t NPerBlock = 32;
    constexpr index_t KPerBlock = 32;
    constexpr index_t BlockSize = 256;

    const index_t kbatch    = 16;
    const index_t grid_size = 64;

    std::vector<index_t> Ms(1, MPerBlock * 4);
    std::vector<index_t> Ns(1, NPerBlock * 4);
    std::vector<index_t> Ks(1, KPerBlock * kbatch);

    EXPECT_TRUE((GroupedGemmStridedTileLoopReduce<BlockSize, MPerBlock, NPerBlock, KPerBlock>{}.Run(
        Ms, Ns, Ks, kbatch, grid_size)));
}

TEST(TestStridedReductionTileLoop, GroupedGemm_MultipleOutputDataTilesPerBlock_1Group)
{
    constexpr index_t MPerBlock = 8;
    constexpr index_t NPerBlock = 32;
    constexpr index_t KPerBlock = 32;
    constexpr index_t BlockSize = 256;

    const index_t kbatch    = 6;
    const index_t grid_size = 3;

    std::vector<index_t> Ms(1, MPerBlock * 2);
    std::vector<index_t> Ns(1, NPerBlock);
    std::vector<index_t> Ks(1, KPerBlock * kbatch);

    EXPECT_TRUE((GroupedGemmStridedTileLoopReduce<BlockSize, MPerBlock, NPerBlock, KPerBlock>{}.Run(
        Ms, Ns, Ks, kbatch, grid_size)));
}

TEST(TestStridedReductionTileLoop, GroupedGemm_MultipleOutputDataTilesPerBlock_NGroup)
{
    constexpr index_t MPerBlock = 8;
    constexpr index_t NPerBlock = 32;
    constexpr index_t KPerBlock = 32;
    constexpr index_t BlockSize = 256;

    const index_t kbatch    = 6;
    const index_t grid_size = 6;

    std::vector<index_t> Ms(2, MPerBlock * 2);
    std::vector<index_t> Ns(2, NPerBlock);
    std::vector<index_t> Ks(2, KPerBlock * kbatch);

    EXPECT_TRUE((GroupedGemmStridedTileLoopReduce<BlockSize, MPerBlock, NPerBlock, KPerBlock>{}.Run(
        Ms, Ns, Ks, kbatch, grid_size)));
}

TEST(TestStridedReductionTileLoop, GroupedGemm_CrossGroups_CrossK_TilePerBlockLTKBatch)
{
    constexpr index_t MPerBlock = 8;
    constexpr index_t NPerBlock = 32;
    constexpr index_t KPerBlock = 32;
    constexpr index_t BlockSize = 256;

    const index_t kbatch    = 5;
    const index_t grid_size = 7;

    std::vector<index_t> Ms(2, MPerBlock * 2);
    std::vector<index_t> Ns(2, NPerBlock);
    std::vector<index_t> Ks(2, KPerBlock * kbatch);

    EXPECT_TRUE((GroupedGemmStridedTileLoopReduce<BlockSize, MPerBlock, NPerBlock, KPerBlock>{}.Run(
        Ms, Ns, Ks, kbatch, grid_size)));
}

TEST(TestStridedReductionTileLoop, GroupedGemm_CrossGroups_CrossK_TilePerBlockGTKBatch)
{
    constexpr index_t MPerBlock = 8;
    constexpr index_t NPerBlock = 32;
    constexpr index_t KPerBlock = 32;
    constexpr index_t BlockSize = 256;

    const index_t kbatch    = 5;
    const index_t grid_size = 5;

    std::vector<index_t> Ms(2, MPerBlock * 2);
    std::vector<index_t> Ns(2, NPerBlock * 2);
    std::vector<index_t> Ks(2, KPerBlock * kbatch);

    EXPECT_TRUE((GroupedGemmStridedTileLoopReduce<BlockSize, MPerBlock, NPerBlock, KPerBlock>{}.Run(
        Ms, Ns, Ks, kbatch, grid_size)));
}

TEST(TestStridedReductionTileLoop, GroupedGemm_CrossGroups_CrossK_TilePerBlockGTKBatch2)
{
    constexpr index_t MPerBlock = 8;
    constexpr index_t NPerBlock = 32;
    constexpr index_t KPerBlock = 32;
    constexpr index_t BlockSize = 256;

    const index_t kbatch = 5;
    // The covered number of tiles is more than actual data tiles.
    const index_t grid_size = 6;

    std::vector<index_t> Ms(2, MPerBlock * 2);
    std::vector<index_t> Ns(2, NPerBlock * 2);
    std::vector<index_t> Ks(2, KPerBlock * kbatch);

    EXPECT_TRUE((GroupedGemmStridedTileLoopReduce<BlockSize, MPerBlock, NPerBlock, KPerBlock>{}.Run(
        Ms, Ns, Ks, kbatch, grid_size)));
}
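For the last three tests, the LTKBatch / GTKBatch suffixes describe how tiles_per_block compares with k_batch. Under the same (assumed) data-tile model as above, the rough arithmetic is:

// Illustrative arithmetic for the CrossGroups_CrossK cases above; the per-group
// data-tile model (output tiles * k_batch) is an assumption, not taken from the commit.
#include <cassert>

int main()
{
    // TilePerBlockLTKBatch: 2 groups of 16x32 with 8x32 blocks -> 2 output tiles per group,
    // k_batch = 5 -> 20 data tiles in total, grid_size = 7.
    assert((20 + 7 - 1) / 7 == 3); // tiles_per_block = 3 < k_batch = 5

    // TilePerBlockGTKBatch: 2 groups of 16x64 -> 4 output tiles per group, 40 data tiles,
    // grid_size = 5.
    assert((40 + 5 - 1) / 5 == 8); // tiles_per_block = 8 > k_batch = 5

    // TilePerBlockGTKBatch2: same shapes, grid_size = 6 -> tiles_per_block = 7,
    // and 6 * 7 = 42 > 40, i.e. the grid covers more tiles than there are data tiles.
    assert((40 + 6 - 1) / 6 == 7);
    return 0;
}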