Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0e67221f
Commit
0e67221f
authored
Jan 27, 2022
by
Chao Liu
Browse files
format
parent
fe027ba3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
8 deletions
+11
-8
test/split_k/main.cpp
test/split_k/main.cpp
+11
-8
No files found.
test/split_k/main.cpp
View file @
0e67221f
...
@@ -21,18 +21,19 @@ enum GemmMatrixLayout
...
@@ -21,18 +21,19 @@ enum GemmMatrixLayout
KM_KN_MN
,
// 2
KM_KN_MN
,
// 2
KM_NK_MN
,
// 3
KM_NK_MN
,
// 3
};
};
using
DeviceGemmNoOpPtr
=
using
DeviceGemmNoOpPtr
=
ck
::
tensor_operation
::
device
::
DeviceGemmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
device
::
DeviceGemmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
using
GEMM_PTR
=
std
::
vector
<
DeviceGemmNoOpPtr
>
;
static
std
::
vector
<
std
::
vector
<
bool
>>&
GetLayoutType
()
static
std
::
vector
<
std
::
vector
<
bool
>>&
GetLayoutType
()
{
{
static
std
::
vector
<
std
::
vector
<
bool
>>
LayOut
=
{{
0
,
0
,
0
},
{
0
,
1
,
0
},
{
1
,
0
,
0
},
{
1
,
1
,
0
}};
static
std
::
vector
<
std
::
vector
<
bool
>>
LayOut
=
{{
0
,
0
,
0
},
{
0
,
1
,
0
},
{
1
,
0
,
0
},
{
1
,
1
,
0
}};
return
LayOut
;
return
LayOut
;
}
}
static
void
add_device_gemm_instance_mk_kn_mn
(
GEMM_PTR
&
gemm_ptrs
)
static
void
add_device_gemm_instance_mk_kn_mn
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
gemm_ptrs
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_splitk_gemm_instance
<
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_splitk_gemm_instance
<
float
,
float
,
...
@@ -42,7 +43,8 @@ static void add_device_gemm_instance_mk_kn_mn(GEMM_PTR& gemm_ptrs)
...
@@ -42,7 +43,8 @@ static void add_device_gemm_instance_mk_kn_mn(GEMM_PTR& gemm_ptrs)
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
gemm_ptrs
);
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
gemm_ptrs
);
}
}
static
void
add_device_gemm_instance_mk_nk_mn
(
GEMM_PTR
&
gemm_ptrs
)
static
void
add_device_gemm_instance_mk_nk_mn
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
gemm_ptrs
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_splitk_gemm_instance
<
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_splitk_gemm_instance
<
float
,
float
,
...
@@ -52,7 +54,7 @@ static void add_device_gemm_instance_mk_nk_mn(GEMM_PTR& gemm_ptrs)
...
@@ -52,7 +54,7 @@ static void add_device_gemm_instance_mk_nk_mn(GEMM_PTR& gemm_ptrs)
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
gemm_ptrs
);
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
gemm_ptrs
);
}
}
static
void
add_device_gemm_instance_km_kn_mn
(
GEMM_PTR
&
gemm_ptrs
)
static
void
add_device_gemm_instance_km_kn_mn
(
std
::
vector
<
DeviceGemmNoOpPtr
>
&
gemm_ptrs
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_splitk_gemm_instance
<
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_splitk_gemm_instance
<
float
,
float
,
...
@@ -62,7 +64,7 @@ static void add_device_gemm_instance_km_kn_mn(GEMM_PTR& gemm_ptrs)
...
@@ -62,7 +64,7 @@ static void add_device_gemm_instance_km_kn_mn(GEMM_PTR& gemm_ptrs)
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
gemm_ptrs
);
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
gemm_ptrs
);
}
}
static
void
add_device_gemm_instance_km_nk_mn
(
GEMM_PTR
&
gemm_ptrs
)
static
void
add_device_gemm_instance_km_nk_mn
(
std
::
vector
<
DeviceGemmNoOpPtr
>
&
gemm_ptrs
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_splitk_gemm_instance
<
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_splitk_gemm_instance
<
float
,
float
,
...
@@ -75,7 +77,7 @@ static void add_device_gemm_instance_km_nk_mn(GEMM_PTR& gemm_ptrs)
...
@@ -75,7 +77,7 @@ static void add_device_gemm_instance_km_nk_mn(GEMM_PTR& gemm_ptrs)
static
auto
&
GetAddDeviceGemmInstance
()
static
auto
&
GetAddDeviceGemmInstance
()
{
{
static
std
::
vector
<
void
(
*
)(
GEMM_PTR
&
)
>
AddDeviceGemmInstance
=
{
static
std
::
vector
<
void
(
*
)(
std
::
vector
<
DeviceGemmNoOpPtr
>
&
)
>
AddDeviceGemmInstance
=
{
add_device_gemm_instance_mk_kn_mn
,
add_device_gemm_instance_mk_kn_mn
,
add_device_gemm_instance_mk_nk_mn
,
add_device_gemm_instance_mk_nk_mn
,
add_device_gemm_instance_km_kn_mn
,
add_device_gemm_instance_km_kn_mn
,
...
@@ -83,7 +85,7 @@ static auto& GetAddDeviceGemmInstance()
...
@@ -83,7 +85,7 @@ static auto& GetAddDeviceGemmInstance()
return
AddDeviceGemmInstance
;
return
AddDeviceGemmInstance
;
}
}
static
void
add_device_gemm_instance
(
GEMM_PTR
&
gemm_ptrs
,
int
layout
)
static
void
add_device_gemm_instance
(
std
::
vector
<
DeviceGemmNoOpPtr
>
&
gemm_ptrs
,
int
layout
)
{
{
GetAddDeviceGemmInstance
()[
layout
](
gemm_ptrs
);
GetAddDeviceGemmInstance
()[
layout
](
gemm_ptrs
);
}
}
...
@@ -104,6 +106,7 @@ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
...
@@ -104,6 +106,7 @@ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
return
true
;
return
true
;
}
}
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
if
(
argc
!=
9
)
if
(
argc
!=
9
)
...
@@ -175,7 +178,7 @@ int main(int argc, char* argv[])
...
@@ -175,7 +178,7 @@ int main(int argc, char* argv[])
c_device_buf
.
ToDevice
(
c_m_n_device_result
.
mData
.
data
());
c_device_buf
.
ToDevice
(
c_m_n_device_result
.
mData
.
data
());
// add device GEMM instances
// add device GEMM instances
GEMM_PTR
gemm_ptrs
;
std
::
vector
<
DeviceGemmNoOpPtr
>
gemm_ptrs
;
add_device_gemm_instance
(
gemm_ptrs
,
layout
);
add_device_gemm_instance
(
gemm_ptrs
,
layout
);
bool
success
=
false
;
bool
success
=
false
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment