gaoqiong / composable_kernel_ROCM · Commits

Commit 1f91449d
first fixes for gemm
Authored Oct 25, 2024 by Jakub Piasecki
Parent: 824809c1

Showing 9 changed files with 118 additions and 35 deletions (+118, -35):
- example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp (+8, -3)
- example/ck_tile/03_gemm/run_gemm_example.inc (+1, -0)
- include/ck_tile/core/config.hpp (+1, -1)
- include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp (+11, -0)
- include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp (+16, -0)
- include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp (+12, -0)
- test/ck_tile/gemm/test_gemm_mem_pipeline.cpp (+1, -0)
- test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc (+50, -28)
- test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp (+18, -3)
example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp (+8, -3)

```diff
@@ -57,8 +57,13 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
     const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
     const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
 
-    float ave_time{0};
+    printf("PrefetchStages: %d\n", BaseGemmPipeline::PrefetchStages);
+    printf("num_loop: %d\n", num_loop);
+    printf("has_hot_loop: %d\n", has_hot_loop);
+    printf("tail_num: %d\n", static_cast<int>(tail_num));
+    float ave_time{0};
 
     const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
         constexpr bool has_hot_loop_v = has_hot_loop_.value;
         constexpr auto tail_number_v  = tail_number_.value;
@@ -86,7 +91,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
     const dim3 grids      = Kernel::GridSize(args.M, args.N, args.kbatch);
     constexpr dim3 blocks = Kernel::BlockSize();
 
-    if(s.log_level_ > 0)
+    if(true)
     {
         std::cout << "Lunching kernel with args:"
                   << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
@@ -169,7 +174,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
     {
         Run(ck_tile::bool_constant<false>{},
             ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
     }
+    // what if not?
     }
 
     return ave_time;
```
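The `Run(ck_tile::bool_constant<...>{}, ck_tile::integral_constant<...>{})` calls in this file turn two runtime values (`has_hot_loop`, `tail_num`) into template parameters, so a separate kernel path is instantiated per combination. A minimal standalone sketch of that idiom using only the standard library; `TailNumber`, `run_dispatch`, and `main` below are stand-ins, not CK code:

```cpp
#include <cstdio>
#include <type_traits>

enum class TailNumber { One, Full };

// Stand-in for the Run lambda: tag arguments carry the runtime decision
// back to compile time through their ::value members.
template <typename HasHotLoop, typename TailNum>
void run_dispatch(HasHotLoop, TailNum)
{
    constexpr bool has_hot_loop_v = HasHotLoop::value; // constexpr again
    constexpr auto tail_number_v  = TailNum::value;
    std::printf("instantiated: hot_loop=%d tail=%d\n",
                has_hot_loop_v, static_cast<int>(tail_number_v));
}

int main()
{
    const bool has_hot_loop   = false;           // runtime input
    const TailNumber tail_num = TailNumber::One; // runtime input

    // One if-branch per (flag, tail) combination selects the instantiation.
    if(!has_hot_loop && tail_num == TailNumber::One)
        run_dispatch(std::integral_constant<bool, false>{},
                     std::integral_constant<TailNumber, TailNumber::One>{});
    return 0;
}
```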
example/ck_tile/03_gemm/run_gemm_example.inc (+1, -0)

```diff
@@ -67,6 +67,7 @@ int run_gemm_example(int argc, char* argv[])
     int n_repeat = arg_parser.get_int("repeat");
 
     using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
+    //using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
     using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
     using CLayout = ck_tile::tensor_layout::gemm::RowMajor;
```
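For orientation, these layout aliases determine how linear offsets are computed from (row, column) indices and a leading dimension. A small sketch of the standard convention assumed here (the library's own stride handling may add padding on top):

```cpp
#include <cstddef>

// Row-major A (M x K): consecutive k's are adjacent in memory; lda >= K.
inline std::size_t idx_row_major(std::size_t m, std::size_t k, std::size_t lda)
{
    return m * lda + k;
}

// Column-major B (K x N): consecutive k's are adjacent within a column; ldb >= K.
inline std::size_t idx_col_major(std::size_t k, std::size_t n, std::size_t ldb)
{
    return n * ldb + k;
}
```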
include/ck_tile/core/config.hpp (+1, -1)

```diff
@@ -166,7 +166,7 @@
 #endif
 
 #ifndef CK_TILE_DEBUG_LOG
-#define CK_TILE_DEBUG_LOG 0
+#define CK_TILE_DEBUG_LOG 1
 #endif
 
 #ifndef __HIP_DEVICE_COMPILE__ // for host code
```
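Flipping the fallback default means every translation unit that does not define `CK_TILE_DEBUG_LOG` itself now compiles with debug logging on. A sketch of how such a compile-time switch is typically consumed; the `DEBUG_LOG` wrapper is illustrative, not the library's actual macro:

```cpp
#include <cstdio>

#ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 1 // the new default after this commit
#endif

// Hypothetical consumer: log statements compile away when the flag is 0.
#if CK_TILE_DEBUG_LOG
#define DEBUG_LOG(...) std::printf(__VA_ARGS__)
#else
#define DEBUG_LOG(...) ((void)0)
#endif

int main()
{
    DEBUG_LOG("emitted only when CK_TILE_DEBUG_LOG != 0\n");
    return 0;
}
```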
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp (+11, -0)

```diff
@@ -138,6 +138,11 @@ struct BlockGemmASmemBSmemCRegV1
     constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
 
     // hot loop:
+    // if(threadIdx.x == 0) {
+    //     printf("block gemm\n");
+    // }
     static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
         static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
             // read A warp tensor from A block window
@@ -162,6 +167,12 @@ struct BlockGemmASmemBSmemCRegV1
                 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                 c_warp_tensor.get_thread_buffer());
+            // if(threadIdx.x == 0) {
+            //     printf("C warp\n");
+            //     tile_elementwise_inout([](auto& c) { printf("%f ", static_cast<float>(c));}, c_block_tensor);
+            //     printf("\n");
+            // }
         });
     });
 });
```
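Both hunks land inside `static_for` loops, which expand the loop body at compile time so each iteration index can feed templates (as in `sequence<mIter, nIter>`). A minimal re-implementation of the idea for reference; ck_tile ships its own `static_for`, so this is a sketch, not the library code:

```cpp
#include <cstdio>
#include <type_traits>
#include <utility>

// Recursive if-constexpr expansion: f receives the index as an
// integral_constant, so kIter.value is usable in constant expressions.
template <int Begin, int End, int Step, typename F>
void static_for_sketch(F&& f)
{
    if constexpr(Begin < End)
    {
        f(std::integral_constant<int, Begin>{});
        static_for_sketch<Begin + Step, End, Step>(std::forward<F>(f));
    }
}

int main()
{
    static_for_sketch<0, 3, 1>([&](auto kIter) {
        std::printf("unrolled iteration %d\n", kIter.value);
    });
    return 0;
}
```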
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp (+16, -0)

```diff
@@ -276,10 +276,25 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
         // initialize C
         tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
 
+        if(threadIdx.x == 0)
+        {
+            printf("gemm_pipeline_ag_bg_cr_mem\n");
+            printf("A in: ");
+            static_for<0, 16, 1>{}([&](auto i) {
+                printf("%f ", static_cast<float>(a_block_tiles.get(I0{}).get_thread_buffer()[i]));
+            });
+            printf("\nB in: ");
+            static_for<0, 16, 1>{}([&](auto i) {
+                printf("%f ", static_cast<float>(b_block_tiles.get(I0{}).get_thread_buffer()[i]));
+            });
+            printf("\n");
+        }
+
         // LDS write 0
         LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
         LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
+        // print a_block_tiles, b_block_tiles
 
         // Global prefetch [1, PrefetchStages]
         static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
             GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
@@ -341,6 +356,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
         if constexpr(TailNum == TailNumber::One)
         {
+            //printf("TailNumOne\n");
             block_sync_lds();
             // block_gemm.LocalPrefetch();
             block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
```
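The inserted dump uses device-side `printf` with a `threadIdx.x == 0` guard so only one lane of the block writes, keeping the output readable. A standalone HIP sketch of the same pattern; the kernel and buffer names are hypothetical, not part of the commit:

```cpp
#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void dump_head(const float* a, int n)
{
    // One thread of one block prints; otherwise every lane emits the line.
    if(blockIdx.x == 0 && threadIdx.x == 0)
    {
        printf("A in: ");
        for(int i = 0; i < n && i < 16; ++i)
            printf("%f ", a[i]);
        printf("\n");
    }
}

int main()
{
    float* d_a = nullptr;
    (void)hipMalloc(&d_a, 16 * sizeof(float));
    (void)hipMemset(d_a, 0, 16 * sizeof(float));
    hipLaunchKernelGGL(dump_head, dim3(1), dim3(64), 0, 0, d_a, 16);
    (void)hipDeviceSynchronize(); // flush device printf before exit
    (void)hipFree(d_a);
    return 0;
}
```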
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp (+12, -0)

```diff
@@ -43,6 +43,18 @@ struct WarpGemmImpl
         const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
         auto c_vec       = c.get_thread_buffer().template get_as<CVec>()[I0];
 
+        // if(threadIdx.x == 0) {
+        //     for(int i=0; i<AWarpTensor::get_thread_buffer_size(); ++i) {
+        //         printf("A[%d]: %d\n", i, static_cast<int32_t>(a_vec[i]));
+        //     }
+        //     for(int i=0; i<BWarpTensor::get_thread_buffer_size(); ++i) {
+        //         printf("B[%d]: %d\n", i, static_cast<int32_t>(b_vec[i]));
+        //     }
+        //     for(int i=0; i<CWarpTensor::get_thread_buffer_size(); ++i) {
+        //         printf("C[%d]: %d\n", i, static_cast<int32_t>(c_vec[i]));
+        //     }
+        // }
+
         // c_vec += a_vec * b_vec
         WarpGemmAttribute{}(c_vec, a_vec, b_vec);
```
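Per the `// c_vec += a_vec * b_vec` comment, each lane accumulates its C register fragment from its A and B fragments. A deliberately simplified scalar model of that per-lane update (the fragment shapes are assumptions; the real `WarpGemmAttribute` maps to AMD matrix-core instructions, not a loop):

```cpp
#include <array>
#include <cstddef>

// Toy per-lane fragment update: an outer-product accumulate, assuming the
// lane owns M a-values, N b-values, and an M*N c-fragment.
template <std::size_t M, std::size_t N>
void lane_fma_model(std::array<float, M * N>& c_vec,
                    const std::array<float, M>& a_vec,
                    const std::array<float, N>& b_vec)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
            c_vec[m * N + n] += a_vec[m] * b_vec[n]; // c_vec += a_vec * b_vec
}
```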
test/ck_tile/gemm/test_gemm_mem_pipeline.cpp (+1, -0)

```diff
@@ -18,6 +18,7 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
 using KernelTypes = ::testing::Types<
     // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
     std::tuple< Row, Col, Row, F16, F16, F32, F16>
+    //std::tuple< Row, Row, Row, F16, F16, F32, F16>
     // TODO: fixme!
     // std::tuple< Col, Row, Row, F16, F16, F32, F16>,
     // std::tuple< Row, Row, Row, F16, F16, F32, F16>,
```
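The type list drives GoogleTest's typed-test machinery: each `std::tuple` row instantiates the whole fixture once, so commenting a row out drops that layout/type combination from every `TYPED_TEST`. A minimal sketch with stand-in types (the fixture and suite names below are illustrative):

```cpp
#include <gtest/gtest.h>
#include <tuple>

template <typename Tuple>
class TypedSuiteSketch : public ::testing::Test {};

// One tuple == one instantiation of every TYPED_TEST in the suite.
using ParamRows = ::testing::Types<std::tuple<float, double>>;
TYPED_TEST_SUITE(TypedSuiteSketch, ParamRows);

TYPED_TEST(TypedSuiteSketch, Smoke)
{
    using First = std::tuple_element_t<0, TypeParam>;
    EXPECT_EQ(sizeof(First), sizeof(float));
}
```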
test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc (+50, -28)

```diff
@@ -2,42 +2,64 @@
 TYPED_TEST(TestCkTileGemmMemPipeline, SmallM)
 {
-    std::vector<int> Ms{1, 2, 3, 4, 5, 6};
-    constexpr int N = 1024;
-    constexpr int K = 320;
+    std::vector<int> Ms{128};
+    std::vector<int> Ns{128};
+    // M K K N M N
+    std::vector<int> Ks{33};
     for(int M : Ms)
-        this->Run(M, N, K);
+        for(int N : Ns)
+            for(int K : Ks)
+                this->Run(M, N, K);
 }
 
-TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM)
-{
-    std::vector<int> Ms{127, 255, 312, 799, 1573};
-    constexpr int N = 1024;
-    constexpr int K = 320;
-    for(int M : Ms)
-        this->Run(M, N, K);
-}
+// TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM)
+// {
+//     std::vector<int> Ms{127, 255, 312, 799, 1573};
+//     constexpr int N = 1024;
+//     constexpr int K = 321;
+//     for(int M : Ms)
+//         this->Run(M, N, K);
+// }
 
 // TODO: Seems like padding is not working!
 // Works only when K is a multiple of KPerBlock
-TYPED_TEST(TestCkTileGemmMemPipeline, DISABLED_PaddK)
-{
-    std::vector<int> Ms{127};
-    constexpr int N = 1024;
-    constexpr int K = 432;
-    for(int M : Ms)
-        this->Run(M, N, K);
-}
+// TYPED_TEST(TestCkTileGemmMemPipeline, PaddK)
+// {
+//     std::vector<int> Ms{1};
+//     constexpr int N = 128;
+//     constexpr int K = 320;
+//     for(int M : Ms)
+//         this->Run(M, N, K);
+// }
 
-TYPED_TEST(TestCkTileGemmMemPipeline, Regular)
-{
-    std::vector<int> Ms{512};
-    constexpr int N = 1024;
-    constexpr int K = 512;
-    for(int M : Ms)
-        this->Run(M, N, K);
-}
+// TYPED_TEST(TestCkTileGemmMemPipeline, PaddKInv)
+// {
+//     std::vector<int> Ms{1};
+//     constexpr int N = 128;
+//     constexpr int K = 322;
+//     for(int M : Ms)
+//         this->Run(M, N, K);
+// }
+// TYPED_TEST(TestCkTileGemmMemPipeline, PaddKInv2)
+// {
+//     std::vector<int> Ms{1};
+//     constexpr int N = 128;
+//     constexpr int K = 346;
+//     for(int M : Ms)
+//         this->Run(M, N, K);
+// }
+// TYPED_TEST(TestCkTileGemmMemPipeline, Regular)
+// {
+//     std::vector<int> Ms{512};
+//     constexpr int N = 1024;
+//     constexpr int K = 512;
+//     for(int M : Ms)
+//         this->Run(M, N, K);
+// }
```
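The one case left active uses K = 33, which forces a partial tail tile in the K loop — exactly the padding path the TODO above flags as broken. The arithmetic, with a hypothetical `KPerBlock` of 32 chosen only for illustration:

```cpp
// K = 33 needs ceil(33 / 32) = 2 K-iterations: one full tile plus a tail
// tile with a single valid element; the other 31 must be padded/masked.
constexpr int KPerBlock = 32; // assumption for illustration
constexpr int K         = 33;
constexpr int num_loop  = (K + KPerBlock - 1) / KPerBlock;
static_assert(num_loop == 2, "one full K tile plus one padded tail tile");
```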
test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp (+18, -3)

```diff
@@ -78,6 +78,10 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
     const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
     const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
+    std::cout << "has hot loop " << has_hot_loop << std::endl;
+    std::cout << "num loop " << num_loop << std::endl;
+    std::cout << "tail_num " << static_cast<int32_t>(tail_num) - 1 << std::endl;
 
     const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
         constexpr bool has_hot_loop_v = has_hot_loop_.value;
         constexpr auto tail_number_v  = tail_number_.value;
@@ -105,7 +109,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
     const dim3 grids      = Kernel::GridSize(args.M, args.N, args.kbatch);
     constexpr dim3 blocks = Kernel::BlockSize();
 
-    if(s.log_level_ > 0)
+    if(true)
     {
         std::cout << "Lunching kernel with args:"
                   << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
@@ -119,14 +123,17 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
     if(has_hot_loop)
     {
+        std::cout << "has hot loop xx\n";
         // Tail pipeline One to Seven
         if(tail_num == ck_tile::TailNumber::One)
         {
+            std::cout << "tail num one\n";
             Run(ck_tile::bool_constant<true>{},
                 ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
         }
         else if(tail_num == ck_tile::TailNumber::Full)
         {
+            std::cout << "tail num full\n";
             Run(ck_tile::bool_constant<true>{},
                 ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
         }
@@ -191,6 +198,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
         // Tail number always 1
         if(tail_num == ck_tile::TailNumber::One)
         {
+            std::cout << "nohotloop tail num one xx\n";
             Run(ck_tile::bool_constant<false>{},
                 ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
         }
@@ -267,8 +275,10 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
         ck_tile::HostTensor<CDataType> c_m_n_dev_result(
             f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
 
-        ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5}(a_m_k);
-        ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5}(b_k_n);
+        ck_tile::FillMonotonicSeq<ADataType>{0, 0.01}(a_m_k);
+        ck_tile::FillMonotonicSeq<BDataType>{0, 0.01}(b_k_n);
+        //ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5}(a_m_k);
+        //ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5}(b_k_n);
 
         ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
         ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
@@ -291,6 +301,11 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
         args.stride_B = stride_B;
         args.stride_C = stride_C;
 
+        std::cout << "kbatch " << kbatch << std::endl;
+        std::cout << "stride A " << stride_A << std::endl;
+        std::cout << "stride B " << stride_B << std::endl;
+        std::cout << "stride C " << stride_C << std::endl;
         invoke_gemm(args, ck_tile::stream_config{nullptr, false});
 
         c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
```
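Replacing the random integer fill with `FillMonotonicSeq{0, 0.01}` makes inputs deterministic, so the device-side dumps added elsewhere in this commit stay comparable across runs. A host-side sketch of the presumed `{init, step}` semantics; the functor below is a stand-in, not ck_tile's implementation:

```cpp
#include <vector>

// Hypothetical FillMonotonicSeq-like functor: writes init, init+step, ...
template <typename T>
struct FillMonotonicSeqSketch
{
    T init;
    T step;

    template <typename Container>
    void operator()(Container& data) const
    {
        T v = init;
        for(auto& x : data)
        {
            x = v;
            v += step;
        }
    }
};

int main()
{
    std::vector<float> a_m_k(8);
    FillMonotonicSeqSketch<float>{0.0f, 0.01f}(a_m_k); // 0.00, 0.01, 0.02, ...
    return 0;
}
```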