Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
62a860a5
Commit
62a860a5
authored
Jan 04, 2022
by
ltqin
Browse files
change desired grid size to kbatch
parent
accb4ca5
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing 6 changed files with 46 additions and 51 deletions
+46
-51
device_operation/include/device_gemm.hpp
device_operation/include/device_gemm.hpp
+13
-14
device_operation/include/device_gemm_splitk_xdl.hpp
device_operation/include/device_gemm_splitk_xdl.hpp
+15
-19
device_operation/include/device_gemm_splitk_xdl_instance.hpp
device_operation/include/device_gemm_splitk_xdl_instance.hpp
+0
-0
profiler/include/profile_gemm_impl.hpp
profiler/include/profile_gemm_impl.hpp
+4
-4
profiler/profile_gemm.cpp
profiler/profile_gemm.cpp
+7
-7
test/split_k/main.cpp
test/split_k/main.cpp
+7
-7
No files found.
device_operation/include/device_gemm.hpp
View file @
62a860a5
...
...
@@ -13,20 +13,19 @@ template <typename AElementwiseOperation,
typename
CElementwiseOperation
>
struct
DeviceGemm
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
ck
::
index_t
desired_gride_size
=
1
)
=
0
;
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
ck
::
index_t
KBatch
=
1
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
...
...
device_operation/include/device_gemm_splitk_xdl.hpp
View file @
62a860a5
...
...
@@ -144,13 +144,11 @@ struct DeviceGemmSplitKXdl
}
}
static
auto
GetK
BatchAndK
Pad
(
index_t
M
,
index_t
N
,
index_t
K
,
index_t
DesiredGridSize
)
static
auto
GetKPad
(
index_t
K
,
index_t
KBatch
)
{
const
auto
GridMN
=
M
*
N
/
(
MPerBlock
*
NPerBlock
);
const
index_t
KBatch
=
std
::
max
(
DesiredGridSize
/
GridMN
,
1
);
const
index_t
K0
=
math
::
integer_divide_ceil
(
K
,
K1
*
K0PerBlock
*
KBatch
)
*
K0PerBlock
;
const
index_t
KPad
=
KBatch
*
K0
*
K1
;
return
std
::
make_tuple
(
KBatch
,
KPad
);
const
index_t
K0
=
math
::
integer_divide_ceil
(
K
,
K1
*
K0PerBlock
*
KBatch
)
*
K0PerBlock
;
const
index_t
KPad
=
KBatch
*
K0
*
K1
;
return
KPad
;
}
using
AGridDesc_K0_M_K1
=
decltype
(
MakeAGridDescriptor_KBatch_K0_M_K1
(
1
,
1
,
1
,
1
,
1
));
...
...
@@ -262,7 +260,7 @@ struct DeviceGemmSplitKXdl
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
index_t
desired_grid_size
)
index_t
k_batch
)
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
...
...
@@ -276,16 +274,14 @@ struct DeviceGemmSplitKXdl
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
},
desired_grid_size_
{
desired_grid_size
}
k_batch_
{
k_batch
}
{
int
KBatch
=
1
,
KPad
=
K
;
std
::
tie
(
KBatch
,
KPad
)
=
DeviceGemmSplitKXdl
::
GetKBatchAndKPad
(
M
,
N
,
K
,
desired_grid_size_
);
int
KPad
=
DeviceGemmSplitKXdl
::
GetKPad
(
K
,
k_batch_
);
a_grid_desc_kbatch_k0_m_k1_
=
DeviceGemmSplitKXdl
::
MakeAGridDescriptor_KBatch_K0_M_K1
(
M
,
K
,
StrideA
,
KB
atch
,
KPad
);
M
,
K
,
StrideA
,
k_b
atch
_
,
KPad
);
b_grid_desc_kbatch_k0_n_k1_
=
DeviceGemmSplitKXdl
::
MakeBGridDescriptor_KBatch_K0_N_K1
(
K
,
N
,
StrideB
,
KB
atch
,
KPad
);
K
,
N
,
StrideB
,
k_b
atch
_
,
KPad
);
c_grid_desc_m_n_
=
DeviceGemmSplitKXdl
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_kbatch_k0_m_k1_
,
...
...
@@ -298,7 +294,7 @@ struct DeviceGemmSplitKXdl
GridwiseGemm
::
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
c_grid_desc_m_n_
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_grid_desc_m_n_
,
M01
,
N01
,
KB
atch
);
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_grid_desc_m_n_
,
M01
,
N01
,
k_b
atch
_
);
}
}
...
...
@@ -316,7 +312,7 @@ struct DeviceGemmSplitKXdl
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
index_t
desired_grid_size
_
;
index_t
k_batch
_
;
};
// Invoker
...
...
@@ -526,7 +522,7 @@ struct DeviceGemmSplitKXdl
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
index_t
desired_grid_Size
)
index_t
KBatch
)
{
return
Argument
{
p_a
,
p_b
,
...
...
@@ -542,7 +538,7 @@ struct DeviceGemmSplitKXdl
a_element_op
,
b_element_op
,
c_element_op
,
desired_grid_Size
};
KBatch
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
@@ -560,7 +556,7 @@ struct DeviceGemmSplitKXdl
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
ck
::
index_t
desired_gride_size
=
1
)
override
ck
::
index_t
KBatch
=
1
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
...
...
@@ -576,7 +572,7 @@ struct DeviceGemmSplitKXdl
a_element_op
,
b_element_op
,
c_element_op
,
desired_gride_size
);
KBatch
);
}
// polymorphic
...
...
device_operation/include/device_gemm_xdl_splitk_instance.hpp → device_operation/include/device_gemm_splitk_xdl_instance.hpp
View file @
62a860a5
File moved
profiler/include/profile_gemm_impl.hpp
View file @
62a860a5
#pragma once
#include "device_gemm_instance.hpp"
#include "device_gemm_
xdl_
splitk_instance.hpp"
#include "device_gemm_splitk_
xdl_
instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -95,7 +95,7 @@ void profile_gemm_impl(int do_verification,
int
StrideA
,
int
StrideB
,
int
StrideC
,
int
DesiredGridSize
=
1
)
int
KBatch
=
1
)
{
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
...
...
@@ -156,7 +156,7 @@ void profile_gemm_impl(int do_verification,
// add device GEMM instances
std
::
vector
<
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
DeviceGemmNoOpPtr
>
gemm_ptrs
;
if
(
DesiredGridSize
>
1
&&
is_same
<
ADataType
,
float
>::
value
)
if
(
KBatch
>
1
&&
is_same
<
ADataType
,
float
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_splitk_gemm_instance
<
float
,
float
,
float
,
ALayout
,
BLayout
,
CLayout
>
(
...
...
@@ -195,7 +195,7 @@ void profile_gemm_impl(int do_verification,
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
DesiredGridSize
);
KBatch
);
auto
invoker_ptr
=
gemm_ptr
->
MakeInvokerPointer
();
...
...
profiler/profile_gemm.cpp
View file @
62a860a5
...
...
@@ -48,7 +48,7 @@ int profile_gemm(int argc, char* argv[])
printf
(
"arg8: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg7: run kernel # of times (>1)
\n
"
);
printf
(
"arg8 to 13: M, N, K, StrideA, StrideB, StrideC
\n
"
);
printf
(
"arg14:
desired grid size
\n
"
);
printf
(
"arg14:
split k into mulitiple batch
\n
"
);
exit
(
1
);
}
...
...
@@ -66,9 +66,9 @@ int profile_gemm(int argc, char* argv[])
const
int
StrideA
=
std
::
stoi
(
argv
[
11
]);
const
int
StrideB
=
std
::
stoi
(
argv
[
12
]);
const
int
StrideC
=
std
::
stoi
(
argv
[
13
]);
int
DesiredGridSize
=
1
;
int
KBatch
=
1
;
if
(
argc
==
15
)
DesiredGridSize
=
std
::
stoi
(
argv
[
14
]);
KBatch
=
std
::
stoi
(
argv
[
14
]);
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
...
...
@@ -164,7 +164,7 @@ int profile_gemm(int argc, char* argv[])
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
DesiredGridSize
);
KBatch
);
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
...
...
@@ -184,7 +184,7 @@ int profile_gemm(int argc, char* argv[])
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
DesiredGridSize
);
KBatch
);
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
...
...
@@ -204,7 +204,7 @@ int profile_gemm(int argc, char* argv[])
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
DesiredGridSize
);
KBatch
);
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
...
...
@@ -224,7 +224,7 @@ int profile_gemm(int argc, char* argv[])
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
DesiredGridSize
);
KBatch
);
}
else
{
...
...
test/split_k/main.cpp
View file @
62a860a5
...
...
@@ -11,7 +11,7 @@
#include "device_gemm_instance.hpp"
#include "host_gemm.hpp"
#include "tensor_layout.hpp"
#include "device_gemm_xdl_instance.hpp"
#include "device_gemm_
splitk_
xdl_instance.hpp"
#include "device_gemm_splitk_xdl.hpp"
enum
GemmMatrixLayout
...
...
@@ -112,7 +112,7 @@ int main(int argc, char* argv[])
printf
(
" 1: A[m, k] * B[n, k] = C[m, n];
\n
"
);
printf
(
" 2: A[k, n] * B[k, n] = C[m, n];
\n
"
);
printf
(
" 3: A[k, n] * B[n, k] = C[m, n])
\n
"
);
printf
(
"arg2 to 7: M, N, K, StrideA, StrideB, StrideC
DesiredGridSize
\n
"
);
printf
(
"arg2 to 7: M, N, K, StrideA, StrideB, StrideC
KBatch
\n
"
);
return
1
;
}
...
...
@@ -122,10 +122,10 @@ int main(int argc, char* argv[])
const
int
N
=
std
::
stoi
(
argv
[
3
]);
const
int
K
=
std
::
stoi
(
argv
[
4
]);
const
int
StrideA
=
std
::
stoi
(
argv
[
5
]);
const
int
StrideB
=
std
::
stoi
(
argv
[
6
]);
const
int
StrideC
=
std
::
stoi
(
argv
[
7
]);
const
int
DesiredGridSize
=
std
::
stoi
(
argv
[
8
]);
const
int
StrideA
=
std
::
stoi
(
argv
[
5
]);
const
int
StrideB
=
std
::
stoi
(
argv
[
6
]);
const
int
StrideC
=
std
::
stoi
(
argv
[
7
]);
const
int
KBatch
=
std
::
stoi
(
argv
[
8
]);
if
(
layout
>
3
||
layout
<
0
)
{
...
...
@@ -194,7 +194,7 @@ int main(int argc, char* argv[])
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
DesiredGridSize
);
KBatch
);
auto
invoker_ptr
=
gemm_ptr
->
MakeInvokerPointer
();
if
(
gemm_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment