Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
98c80714
Unverified
Commit
98c80714
authored
Oct 10, 2023
by
Bartłomiej Kocot
Committed by
GitHub
Oct 10, 2023
Browse files
Fix MNKPadding in gridwise_gemm_xdlops_v2r3 (#981)
parent
ac9595a9
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
189 additions
and
473 deletions
+189
-473
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+2
-1
test/batched_gemm/CMakeLists.txt
test/batched_gemm/CMakeLists.txt
+2
-16
test/batched_gemm/batched_gemm_bf16.cpp
test/batched_gemm/batched_gemm_bf16.cpp
+0
-114
test/batched_gemm/batched_gemm_fp16.cpp
test/batched_gemm/batched_gemm_fp16.cpp
+0
-114
test/batched_gemm/batched_gemm_fp32.cpp
test/batched_gemm/batched_gemm_fp32.cpp
+0
-114
test/batched_gemm/batched_gemm_int8.cpp
test/batched_gemm/batched_gemm_int8.cpp
+0
-114
test/batched_gemm/test_batched_gemm.cpp
test/batched_gemm/test_batched_gemm.cpp
+185
-0
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
98c80714
...
@@ -945,7 +945,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
...
@@ -945,7 +945,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
}
}
}();
}();
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
{
{
return
transform_tensor_descriptor
(
c_grid_desc_m_n
,
return
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_right_pad_transform
(
M
,
MPad
-
M
),
make_tuple
(
make_right_pad_transform
(
M
,
MPad
-
M
),
...
...
test/batched_gemm/CMakeLists.txt
View file @
98c80714
...
@@ -2,22 +2,8 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
...
@@ -2,22 +2,8 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_test_executable
(
test_batched_gemm_fp16 batched_gemm_fp16.cpp
)
add_gtest_executable
(
test_batched_gemm test_batched_gemm.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm PRIVATE utility device_batched_gemm_instance
)
target_link_libraries
(
test_batched_gemm_fp16 PRIVATE utility device_batched_gemm_instance
)
endif
()
add_test_executable
(
test_batched_gemm_fp32 batched_gemm_fp32.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm_fp32 PRIVATE utility device_batched_gemm_instance
)
endif
()
add_test_executable
(
test_batched_gemm_bf16 batched_gemm_bf16.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm_bf16 PRIVATE utility device_batched_gemm_instance
)
endif
()
add_test_executable
(
test_batched_gemm_int8 batched_gemm_int8.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm_int8 PRIVATE utility device_batched_gemm_instance
)
endif
()
set
(
target 1
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
endforeach
()
\ No newline at end of file
test/batched_gemm/batched_gemm_bf16.cpp
deleted
100644 → 0
View file @
ac9595a9
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
namespace
{
using
ADataType
=
ck
::
bhalf_t
;
using
BDataType
=
ck
::
bhalf_t
;
using
CDataType
=
ck
::
bhalf_t
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
}
// namespace
int
main
()
{
int
M
=
256
;
int
N
=
256
;
int
K
=
128
;
int
BatchCount
=
3
;
bool
pass
=
true
;
using
namespace
ck
::
tensor_operation
::
device
;
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Row
,
Row
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Row
,
Col
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Col
,
Row
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Col
,
Col
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
std
::
cout
<<
"test BatchedGEMM bf16: "
<<
(
pass
?
"Pass"
:
"Fail"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
test/batched_gemm/batched_gemm_fp16.cpp
deleted
100644 → 0
View file @
ac9595a9
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
namespace
{
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
}
// namespace
int
main
()
{
int
M
=
512
;
int
N
=
256
;
int
K
=
128
;
int
BatchCount
=
3
;
bool
pass
=
true
;
using
namespace
ck
::
tensor_operation
::
device
;
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Row
,
Row
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Row
,
Col
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Col
,
Row
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Col
,
Col
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
std
::
cout
<<
"test BatchedGEMM fp16: "
<<
(
pass
?
"Pass"
:
"Fail"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
test/batched_gemm/batched_gemm_fp32.cpp
deleted
100644 → 0
View file @
ac9595a9
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
namespace
{
using
ADataType
=
float
;
using
BDataType
=
float
;
using
CDataType
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
}
// namespace
int
main
()
{
int
M
=
256
;
int
N
=
256
;
int
K
=
128
;
int
BatchCount
=
3
;
bool
pass
=
true
;
using
namespace
ck
::
tensor_operation
::
device
;
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Row
,
Row
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Row
,
Col
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Col
,
Row
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Col
,
Col
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
std
::
cout
<<
"test BatchedGEMM fp32: "
<<
(
pass
?
"Pass"
:
"Fail"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
test/batched_gemm/batched_gemm_int8.cpp
deleted
100644 → 0
View file @
ac9595a9
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
namespace
{
using
ADataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
CDataType
=
int8_t
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
}
// namespace
int
main
()
{
int
M
=
256
;
int
N
=
256
;
int
K
=
128
;
int
BatchCount
=
3
;
bool
pass
=
true
;
using
namespace
ck
::
tensor_operation
::
device
;
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Row
,
Row
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Row
,
Col
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Col
,
Row
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Col
,
Col
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
std
::
cout
<<
"test BatchedGEMM int8: "
<<
(
pass
?
"Pass"
:
"Fail"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
test/batched_gemm/test_batched_gemm.cpp
0 → 100644
View file @
98c80714
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <tuple>
#include <vector>
#include <gtest/gtest.h>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
// One batched GEMM problem size handed to the profiler by TestBatchedGemm.
// Plain aggregate: brace-initialise in the order {M, N, K, BatchCount}.
struct GemmParams
{
    ck::index_t M;          // forwarded as the profiler's M argument
    ck::index_t N;          // forwarded as the profiler's N argument
    ck::index_t K;          // forwarded as the profiler's K argument
    ck::index_t BatchCount; // forwarded as the profiler's last argument
};
class
TestBatchedGemm
:
public
::
testing
::
Test
{
protected:
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
std
::
vector
<
GemmParams
>
params
;
template
<
typename
DataType
>
void
Run
()
{
using
namespace
ck
::
tensor_operation
::
device
;
bool
pass
=
true
;
for
(
auto
&
param
:
params
)
{
const
auto
M
=
param
.
M
;
const
auto
N
=
param
.
N
;
const
auto
K
=
param
.
K
;
const
auto
BatchCount
=
param
.
BatchCount
;
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
DataType
,
DataType
,
DataType
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Row
,
Row
,
Row
,
DataType
,
DataType
,
DataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
DataType
,
DataType
,
DataType
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Row
,
Col
,
Row
,
DataType
,
DataType
,
DataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
DataType
,
DataType
,
DataType
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Col
,
Row
,
Row
,
DataType
,
DataType
,
DataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
DataType
,
DataType
,
DataType
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Col
,
Col
,
Row
,
DataType
,
DataType
,
DataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
}
EXPECT_TRUE
(
pass
);
}
};
#ifdef CK_ENABLE_INT8
// int8 batched GEMM over a fixed list of problem sizes.
TEST_F(TestBatchedGemm, i8)
{
    // Each entry is {M, N, K, BatchCount}, appended in this order.
    const GemmParams shapes[] = {{64, 64, 64, 2},
                                 {64, 64, 64, 1},
                                 {60, 60, 60, 2},
                                 {68, 68, 68, 2},
                                 {40, 40, 40, 2},
                                 {256, 256, 128, 3}};

    for(const GemmParams& shape : shapes)
    {
        params.push_back(shape);
    }

    Run<int8_t>();
}
#endif
#ifdef CK_ENABLE_BF16
// bf16 batched GEMM over a fixed list of problem sizes.
TEST_F(TestBatchedGemm, bf16)
{
    // Each entry is {M, N, K, BatchCount}, appended in this order.
    const GemmParams shapes[] = {{64, 64, 64, 2},
                                 {64, 64, 64, 1},
                                 {60, 60, 60, 2},
                                 {68, 68, 68, 2},
                                 {40, 40, 40, 2},
                                 {256, 256, 128, 3}};

    for(const GemmParams& shape : shapes)
    {
        params.push_back(shape);
    }

    Run<ck::bhalf_t>();
}
#endif
#ifdef CK_ENABLE_FP16
// fp16 batched GEMM over a fixed list of problem sizes.
TEST_F(TestBatchedGemm, fp16)
{
    // Each entry is {M, N, K, BatchCount}, appended in this order.
    const GemmParams shapes[] = {{64, 64, 64, 2},
                                 {64, 64, 64, 1},
                                 {60, 60, 60, 2},
                                 {68, 68, 68, 2},
                                 {40, 40, 40, 2},
                                 {256, 256, 128, 3}};

    for(const GemmParams& shape : shapes)
    {
        params.push_back(shape);
    }

    Run<ck::half_t>();
}
#endif
#ifdef CK_ENABLE_FP32
// fp32 batched GEMM over a fixed list of problem sizes.
TEST_F(TestBatchedGemm, fp32)
{
    // Each entry is {M, N, K, BatchCount}, appended in this order.
    const GemmParams shapes[] = {{64, 64, 64, 2},
                                 {64, 64, 64, 1},
                                 {60, 60, 60, 2},
                                 {68, 68, 68, 2},
                                 {40, 40, 40, 2},
                                 {256, 256, 128, 3}};

    for(const GemmParams& shape : shapes)
    {
        params.push_back(shape);
    }

    Run<float>();
}
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment