Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
3f74017c
"examples/training/train_unconditional.py" did not exist on "e795a4c6f88a51a5d1b16b47d36f8e103f0a82ac"
Commit
3f74017c
authored
Jun 05, 2024
by
Adam Osewski
Browse files
Clean up.
parent
9177a207
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
21 deletions
+14
-21
example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_single_kernel_fp16.cpp
...grouped_gemm_multiple_d_splitk_xdl_single_kernel_fp16.cpp
+14
-21
No files found.
example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_single_kernel_fp16.cpp
View file @
3f74017c
...
@@ -32,16 +32,16 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
...
@@ -32,16 +32,16 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
F16
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
BDataType
=
F16
;
using
AccDataType
=
F32
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
DsDataType
=
ck
::
Tuple
<>
;
using
DsDataType
=
ck
::
Tuple
<>
;
using
EDataType
=
F16
;
using
EDataType
=
F16
;
using
ALayout
=
Row
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
BLayout
=
Col
;
// using BLayout = Row;
using
DsLayout
=
ck
::
Tuple
<>
;
using
DsLayout
=
ck
::
Tuple
<>
;
using
ELayout
=
Row
;
using
ELayout
=
Row
;
...
@@ -57,9 +57,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmMultip
...
@@ -57,9 +57,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmMultip
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmMNKPadding
,
1
,
128
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
ck
::
PipelineVersion
::
v1
>
;
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmMNKPadding
,
1
,
128
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
ck
::
PipelineVersion
::
v1
>
;
// < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
// < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>;
// clang-format on
// clang-format on
struct
ProblemSize
final
struct
ProblemSize
final
...
@@ -79,9 +77,8 @@ struct ExecutionConfig final
...
@@ -79,9 +77,8 @@ struct ExecutionConfig final
{
{
bool
do_verification
=
true
;
bool
do_verification
=
true
;
int
init_method
=
1
;
int
init_method
=
1
;
// int k_batch = 128;
int
k_batch
=
36
;
int
k_batch
=
1
;
bool
time_kernel
=
false
;
bool
time_kernel
=
false
;
};
};
bool
run_grouped_gemm
(
const
ProblemSize
&
problem_size
,
const
ExecutionConfig
&
config
)
bool
run_grouped_gemm
(
const
ProblemSize
&
problem_size
,
const
ExecutionConfig
&
config
)
...
@@ -155,23 +152,22 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
...
@@ -155,23 +152,22 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
{
{
case
0
:
break
;
case
0
:
break
;
case
1
:
case
1
:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
5
,
5
}
(
a_tensors
[
i
]
);
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
-
5
,
5
}
(
b_tensors
[
i
]
);
break
;
break
;
case
2
:
case
2
:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
ck
::
utils
::
FillUniformDistribution
<
ADataType
>
{
0.0
,
1.0
}
(
a_tensors
[
i
]
);
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
ck
::
utils
::
FillUniformDistribution
<
BDataType
>
{
-
0.5
,
0.5
}
(
b_tensors
[
i
]
);
break
;
break
;
case
3
:
case
3
:
ck
::
utils
::
FillConstant
<
ADataType
>
{
1
}(
a_tensors
[
i
]);
ck
::
utils
::
FillConstant
<
ADataType
>
{
1
}(
a_tensors
[
i
]);
ck
::
utils
::
FillConstant
<
BDataType
>
{
1
}(
b_tensors
[
i
]);
ck
::
utils
::
FillConstant
<
BDataType
>
{
1
}(
b_tensors
[
i
]);
break
;
break
;
default:
default:
// a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
// b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
ck
::
utils
::
FillMonotonicSeq
<
ADataType
>
{
0
,
1
}(
a_tensors
[
i
]);
ck
::
utils
::
FillMonotonicSeq
<
ADataType
>
{
0
,
1
}(
a_tensors
[
i
]);
ck
::
utils
::
FillMonotonicSeq
<
BDataType
>
{
1
,
1
}(
b_tensors
[
i
]);
ck
::
utils
::
FillMonotonicSeq
<
BDataType
>
{
1
,
1
}(
b_tensors
[
i
]);
}
}
c_device_tensors
[
i
].
SetZero
();
}
}
using
GroupedGemmKernelArgument
=
using
GroupedGemmKernelArgument
=
...
@@ -319,20 +315,17 @@ int main(int argc, char* argv[])
...
@@ -319,20 +315,17 @@ int main(int argc, char* argv[])
if
(
argc
<
11
)
if
(
argc
<
11
)
{
{
// std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
std
::
vector
<
ck
::
index_t
>
Ms
{
64
,
127
,
255
,
1
,
129
,
260
,
190
,
77
};
std
::
vector
<
ck
::
index_t
>
Ms
{
64
};
problem_size
.
group_count
=
Ms
.
size
();
problem_size
.
group_count
=
Ms
.
size
();
for
(
int
i
=
0
;
i
<
problem_size
.
group_count
;
i
++
)
for
(
int
i
=
0
;
i
<
problem_size
.
group_count
;
i
++
)
{
{
problem_size
.
Ms
.
push_back
(
Ms
[
i
]);
problem_size
.
Ms
.
push_back
(
Ms
[
i
]);
// problem_size.Ns.push_back(252);
problem_size
.
Ns
.
push_back
(
256
);
problem_size
.
Ns
.
push_back
(
256
);
problem_size
.
Ks
.
push_back
(
4608
);
problem_size
.
Ks
.
push_back
(
4608
);
problem_size
.
stride_As
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_As
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_Bs
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_Bs
.
push_back
(
problem_size
.
Ks
[
i
]);
// problem_size.stride_Bs.push_back(problem_size.Ns[i]);
problem_size
.
stride_Cs
.
push_back
(
problem_size
.
Ns
[
i
]);
problem_size
.
stride_Cs
.
push_back
(
problem_size
.
Ns
[
i
]);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment