Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5d61cd96
Commit
5d61cd96
authored
Jul 19, 2023
by
Jing Zhang
Browse files
add grouped_gemm_bias example
parent
c0c3e21e
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
401 additions
and
85 deletions
+401
-85
example/15_grouped_gemm/CMakeLists.txt
example/15_grouped_gemm/CMakeLists.txt
+2
-1
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
...e/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
+372
-0
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
...tion/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
+20
-17
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
...r_operation/gpu/element/binary_element_wise_operation.hpp
+0
-66
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+6
-0
No files found.
example/15_grouped_gemm/CMakeLists.txt
View file @
5d61cd96
...
@@ -8,6 +8,7 @@ add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_mult
...
@@ -8,6 +8,7 @@ add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_mult
add_example_executable
(
example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_fixed_nk_bias_fp16 grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
)
add_dependencies
(
example_grouped_gemm_xdl
add_dependencies
(
example_grouped_gemm_xdl
...
@@ -17,7 +18,7 @@ add_dependencies(example_grouped_gemm_xdl
...
@@ -17,7 +18,7 @@ add_dependencies(example_grouped_gemm_xdl
example_grouped_gemm_xdl_int8
example_grouped_gemm_xdl_int8
example_grouped_gemm_multiple_d_dl_fp16
example_grouped_gemm_multiple_d_dl_fp16
example_grouped_gemm_xdl_splitk_fp16
example_grouped_gemm_xdl_splitk_fp16
example_grouped_gemm_xdl_fixed_nk_fp16
example_grouped_gemm_xdl_fixed_nk_
bias_
fp16
)
)
if
(
USE_BITINT_EXTENSION_INT4
)
if
(
USE_BITINT_EXTENSION_INT4
)
...
...
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
0 → 100644
View file @
5d61cd96
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
D0DataType
=
F32
;
using
DsDataType
=
ck
::
Tuple
<
D0DataType
>
;
using
EDataType
=
F16
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
D0Layout
=
Row
;
using
DsLayout
=
ck
::
Tuple
<
D0Layout
>
;
using
ELayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
struct
Add
{
template
<
typename
E
,
typename
C
,
typename
D0
>
__host__
__device__
void
operator
()(
E
&
e
,
const
C
&
c
,
const
D0
&
d0
)
const
;
template
<
>
__host__
__device__
void
operator
()
<
ck
::
half_t
,
float
,
float
>
(
ck
::
half_t
&
e
,
const
float
&
c
,
const
float
&
d0
)
const
{
e
=
c
+
d0
;
}
template
<
>
__host__
__device__
void
operator
()
<
ck
::
half_t
,
ck
::
half_t
,
float
>
(
ck
::
half_t
&
e
,
const
ck
::
half_t
&
c
,
const
float
&
d0
)
const
{
e
=
c
+
d0
;
}
};
using
CDEElementOp
=
Add
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
;
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedGemm_Xdl_Fixed_NK
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
// clang-format on
struct
ProblemSize
final
{
std
::
vector
<
ck
::
index_t
>
Ms
;
std
::
vector
<
ck
::
index_t
>
Ns
;
std
::
vector
<
ck
::
index_t
>
Ks
;
std
::
vector
<
ck
::
index_t
>
stride_As
;
std
::
vector
<
ck
::
index_t
>
stride_Bs
;
std
::
vector
<
ck
::
index_t
>
stride_Cs
;
ck
::
index_t
group_count
;
};
struct
ExecutionConfig
final
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
};
bool
run_grouped_gemm
(
const
ProblemSize
&
problem_size
,
const
ExecutionConfig
&
config
)
{
auto
group_count
=
problem_size
.
group_count
;
// GEMM shape
std
::
vector
<
ck
::
tensor_operation
::
device
::
GemmDesc
>
gemm_descs
;
std
::
vector
<
std
::
array
<
const
void
*
,
1
>>
p_Ds
;
std
::
vector
<
void
*>
p_Cs
;
gemm_descs
.
reserve
(
group_count
);
int
sum_of_m
=
0
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
using
namespace
ck
::
literals
;
if
(
std
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
HostTensorDescriptor
({
row
,
col
},
{
stride
,
1
_uz
});
}
else
{
return
HostTensorDescriptor
({
row
,
col
},
{
1
_uz
,
stride
});
}
};
std
::
vector
<
Tensor
<
ADataType
>>
a_tensors
;
std
::
vector
<
Tensor
<
BDataType
>>
b_tensors
;
std
::
vector
<
Tensor
<
D0DataType
>>
d0_tensors
;
std
::
vector
<
Tensor
<
EDataType
>>
c_host_tensors
;
std
::
vector
<
Tensor
<
EDataType
>>
c_device_tensors
;
a_tensors
.
reserve
(
group_count
);
b_tensors
.
reserve
(
group_count
);
d0_tensors
.
reserve
(
group_count
);
c_host_tensors
.
reserve
(
group_count
);
c_device_tensors
.
reserve
(
group_count
);
using
DeviceMemPtr
=
std
::
unique_ptr
<
DeviceMem
>
;
std
::
vector
<
DeviceMemPtr
>
a_tensors_device
,
b_tensors_device
,
d0_tensors_device
,
c_tensors_device
;
a_tensors_device
.
reserve
(
group_count
);
b_tensors_device
.
reserve
(
group_count
);
d0_tensors_device
.
reserve
(
group_count
);
c_tensors_device
.
reserve
(
group_count
);
std
::
size_t
flop
=
0
,
num_btype
=
0
;
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
sum_of_m
+=
problem_size
.
Ms
[
i
];
a_tensors
.
push_back
(
Tensor
<
ADataType
>
(
f_host_tensor_descriptor
(
problem_size
.
Ms
[
i
],
problem_size
.
Ks
[
i
],
problem_size
.
stride_As
[
i
],
ALayout
{})));
b_tensors
.
push_back
(
Tensor
<
BDataType
>
(
f_host_tensor_descriptor
(
problem_size
.
Ks
[
i
],
problem_size
.
Ns
[
i
],
problem_size
.
stride_Bs
[
i
],
BLayout
{})));
d0_tensors
.
push_back
(
Tensor
<
D0DataType
>
(
f_host_tensor_descriptor
(
problem_size
.
Ms
[
i
],
problem_size
.
Ns
[
i
],
0
,
ELayout
{})));
c_host_tensors
.
push_back
(
Tensor
<
EDataType
>
(
f_host_tensor_descriptor
(
problem_size
.
Ms
[
i
],
problem_size
.
Ns
[
i
],
problem_size
.
stride_Cs
[
i
],
ELayout
{})));
c_device_tensors
.
push_back
(
Tensor
<
EDataType
>
(
f_host_tensor_descriptor
(
problem_size
.
Ms
[
i
],
problem_size
.
Ns
[
i
],
problem_size
.
stride_Cs
[
i
],
ELayout
{})));
std
::
cout
<<
"gemm["
<<
i
<<
"] a_m_k: "
<<
a_tensors
[
i
].
mDesc
<<
" b_k_n: "
<<
b_tensors
[
i
].
mDesc
<<
" d_m_n: "
<<
d0_tensors
[
i
].
mDesc
<<
" c_m_n: "
<<
c_device_tensors
[
i
].
mDesc
<<
std
::
endl
;
flop
+=
std
::
size_t
(
2
)
*
problem_size
.
Ms
[
i
]
*
problem_size
.
Ks
[
i
]
*
problem_size
.
Ns
[
i
];
num_btype
+=
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSize
()
+
sizeof
(
BDataType
)
*
b_tensors
[
i
].
mDesc
.
GetElementSize
()
+
sizeof
(
D0DataType
)
*
d0_tensors
[
i
].
mDesc
.
GetElementSize
()
+
sizeof
(
EDataType
)
*
c_device_tensors
[
i
].
mDesc
.
GetElementSize
();
switch
(
config
.
init_method
)
{
case
0
:
break
;
case
1
:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
break
;
case
2
:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
break
;
default:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
0
>
{});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
}
d0_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
}
using
GroupedGemmKernelArgument
=
ck
::
tensor_operation
::
device
::
GroupedGemmKernelArgument
<
1
>
;
std
::
vector
<
GroupedGemmKernelArgument
>
grouped_gemm_kernel_args_
;
grouped_gemm_kernel_args_
.
reserve
(
group_count
);
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
a_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ADataType
)
*
sum_of_m
*
problem_size
.
Ks
[
i
]));
b_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
BDataType
)
*
problem_size
.
Ns
[
i
]
*
problem_size
.
Ks
[
i
]));
d0_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
D0DataType
)
*
problem_size
.
Ns
[
i
]));
c_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
EDataType
)
*
sum_of_m
*
problem_size
.
Ns
[
i
]));
a_tensors_device
[
i
]
->
ToDevice
(
a_tensors
[
i
].
mData
.
data
(),
a_tensors
[
i
].
mDesc
.
GetElementSpaceSize
()
*
sizeof
(
ADataType
));
b_tensors_device
[
i
]
->
ToDevice
(
b_tensors
[
i
].
mData
.
data
(),
b_tensors
[
i
].
mDesc
.
GetElementSpaceSize
()
*
sizeof
(
BDataType
));
d0_tensors_device
[
i
]
->
ToDevice
(
d0_tensors
[
i
].
mData
.
data
());
c_tensors_device
[
i
]
->
SetZero
();
p_Ds
.
push_back
(
std
::
array
<
const
void
*
,
1
>
{
d0_tensors_device
[
i
]
->
GetDeviceBuffer
()});
p_Cs
.
push_back
(
c_tensors_device
[
i
]
->
GetDeviceBuffer
());
gemm_descs
.
push_back
({
sum_of_m
,
problem_size
.
Ns
[
i
],
problem_size
.
Ks
[
i
],
problem_size
.
stride_As
[
i
],
problem_size
.
stride_Bs
[
i
],
problem_size
.
stride_Cs
[
i
],
{
0
}});
grouped_gemm_kernel_args_
.
push_back
(
{
a_tensors_device
[
i
]
->
GetDeviceBuffer
(),
b_tensors_device
[
i
]
->
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
d0_tensors_device
[
i
]
->
GetDeviceBuffer
()},
c_tensors_device
[
i
]
->
GetDeviceBuffer
(),
problem_size
.
Ms
[
i
],
problem_size
.
Ns
[
i
],
problem_size
.
Ks
[
i
],
problem_size
.
stride_As
[
i
],
problem_size
.
stride_Bs
[
i
],
std
::
array
<
ck
::
index_t
,
1
>
{
0
},
problem_size
.
stride_Cs
[
i
]});
}
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
cde_element_op
=
CDEElementOp
{};
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
std
::
vector
<
const
void
*>
p_As
=
{};
std
::
vector
<
const
void
*>
p_Bs
=
{};
// do GEMM
auto
argument
=
gemm
.
MakeArgument
(
p_As
,
p_Bs
,
p_Ds
,
p_Cs
,
gemm_descs
,
a_element_op
,
b_element_op
,
cde_element_op
);
DeviceMem
gemm_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
hip_check_error
(
hipMemcpy
(
gemm_desc_workspace
.
GetDeviceBuffer
(),
grouped_gemm_kernel_args_
.
data
(),
gemm
.
GetWorkSpaceSize
(
&
argument
),
hipMemcpyHostToDevice
));
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
throw
std
::
runtime_error
(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"
);
}
gemm
.
SetDeviceKernelArgs
(
argument
,
gemm_desc_workspace
.
GetDeviceBuffer
());
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
bool
pass
=
true
;
if
(
config
.
do_verification
)
{
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
EDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
PassThrough
>
;
for
(
std
::
size_t
i
=
0
;
i
<
gemm_descs
.
size
();
i
++
)
{
c_tensors_device
[
i
]
->
FromDevice
(
c_device_tensors
[
i
].
mData
.
data
(),
c_device_tensors
[
i
].
mDesc
.
GetElementSize
()
*
sizeof
(
EDataType
));
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_tensors
[
i
],
b_tensors
[
i
],
c_host_tensors
[
i
],
a_element_op
,
b_element_op
,
PassThrough
{});
ref_invoker
.
Run
(
ref_argument
);
for
(
int
m
=
0
;
m
<
problem_size
.
Ms
[
i
];
++
m
)
{
for
(
int
n
=
0
;
n
<
problem_size
.
Ns
[
i
];
++
n
)
{
cde_element_op
(
c_host_tensors
[
i
](
m
,
n
),
c_host_tensors
[
i
](
m
,
n
),
d0_tensors
[
i
](
m
,
n
));
}
}
pass
&=
ck
::
utils
::
check_err
(
c_device_tensors
[
i
],
c_host_tensors
[
i
]);
}
}
if
(
config
.
time_kernel
)
{
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
});
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
}
return
pass
;
}
// int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
int
main
(
int
argc
,
char
*
argv
[])
{
ProblemSize
problem_size
;
ExecutionConfig
config
;
problem_size
.
group_count
=
16
;
problem_size
.
Ms
=
{
167
,
183
,
177
,
181
,
153
,
139
,
156
,
173
,
163
,
150
,
204
,
184
,
168
,
156
,
168
,
148
};
for
(
int
i
=
0
;
i
<
problem_size
.
group_count
;
i
++
)
{
problem_size
.
Ns
.
push_back
(
768
);
problem_size
.
Ks
.
push_back
(
4608
);
problem_size
.
stride_As
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_Bs
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_Cs
.
push_back
(
problem_size
.
Ns
[
i
]);
}
if
(
argc
==
4
)
{
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
init_method
=
std
::
stoi
(
argv
[
2
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=n0, 1=yes)
\n
"
);
exit
(
0
);
}
return
!
run_grouped_gemm
(
problem_size
,
config
);
}
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
View file @
5d61cd96
...
@@ -34,7 +34,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
...
@@ -34,7 +34,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using
ADataType
=
F16
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
BDataType
=
F16
;
using
AccDataType
=
F32
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F
16
;
using
CShuffleDataType
=
F
32
;
using
DsDataType
=
ck
::
Tuple
<>
;
using
DsDataType
=
ck
::
Tuple
<>
;
using
EDataType
=
F16
;
using
EDataType
=
F16
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
View file @
5d61cd96
...
@@ -29,6 +29,7 @@ template <typename GridwiseGemm,
...
@@ -29,6 +29,7 @@ template <typename GridwiseGemm,
typename
BLayout
,
typename
BLayout
,
typename
DsLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
ELayout
,
typename
DsDataType
,
typename
Block2ETileMap
,
typename
Block2ETileMap
,
typename
GroupedGemmBlock2ETileMap
,
typename
GroupedGemmBlock2ETileMap
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
...
@@ -108,18 +109,6 @@ __global__ void
...
@@ -108,18 +109,6 @@ __global__ void
const
auto
StrideDs
=
gemm_desc_ptr
[
group_id
].
StrideDs
;
const
auto
StrideDs
=
gemm_desc_ptr
[
group_id
].
StrideDs
;
const
auto
StrideE
=
gemm_desc_ptr
[
group_id
].
StrideE
;
const
auto
StrideE
=
gemm_desc_ptr
[
group_id
].
StrideE
;
#if 0
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using ALayout = Row;
using BLayout = Col;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
#endif
using
DsDataType
=
ck
::
Tuple
<>
;
const
auto
e_grid_desc_m_n
=
const
auto
e_grid_desc_m_n
=
GridwiseGemm
::
template
MakeEGridDescriptor_M_N
<
ELayout
,
GemmSpec
>(
M
,
N
,
StrideE
);
GridwiseGemm
::
template
MakeEGridDescriptor_M_N
<
ELayout
,
GemmSpec
>(
M
,
N
,
StrideE
);
...
@@ -127,7 +116,7 @@ __global__ void
...
@@ -127,7 +116,7 @@ __global__ void
const
auto
local_b2e_tile_map
=
Block2ETileMap
{
e_grid_desc_m_n
};
const
auto
local_b2e_tile_map
=
Block2ETileMap
{
e_grid_desc_m_n
};
constexpr
auto
NumDTensor
=
0
;
constexpr
auto
NumDTensor
=
DsDataType
::
Size
()
;
using
DsGridPointer
=
decltype
(
GridwiseGemm
::
MakeDsGridPointer
());
using
DsGridPointer
=
decltype
(
GridwiseGemm
::
MakeDsGridPointer
());
...
@@ -580,10 +569,9 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -580,10 +569,9 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
throw
std
::
runtime_error
(
"wrong! group_count_ != p_Bs || 0 != p_Bs.size"
);
throw
std
::
runtime_error
(
"wrong! group_count_ != p_Bs || 0 != p_Bs.size"
);
}
}
if
(
!
(
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_Ds
.
size
())
||
if
(
!
(
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_Ds
.
size
())
||
NumDTensor
==
0
))
0
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_Ds
.
size
())))
{
{
throw
std
::
runtime_error
(
"wrong! group_count_ != p_Ds
|| 0 != p_Ds.size
"
);
throw
std
::
runtime_error
(
"wrong! group_count_ != p_Ds"
);
}
}
if
(
!
(
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_Es
.
size
())))
if
(
!
(
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_Es
.
size
())))
...
@@ -648,11 +636,17 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -648,11 +636,17 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
const
index_t
grid_size_grp
=
local_b2c_tile_map
.
CalculateGridSize
(
e_grid_desc_m_n
);
const
index_t
grid_size_grp
=
local_b2c_tile_map
.
CalculateGridSize
(
e_grid_desc_m_n
);
std
::
cout
<<
"grp id: "
<<
group_id
<<
" grid_size: "
<<
grid_size_grp
<<
std
::
endl
;
// std::cout << "grp id: " << group_id << " grid_size: " << grid_size_grp <<
// std::endl;
const
index_t
BlockStart
=
grid_size_
;
const
index_t
BlockStart
=
grid_size_
;
const
index_t
BlockEnd
=
grid_size_
+
grid_size_grp
;
const
index_t
BlockEnd
=
grid_size_
+
grid_size_grp
;
if
(
group_id
*
grid_size_grp
!=
grid_size_
)
{
throw
std
::
runtime_error
(
"wrong! grid_size_grp is not identical!"
);
}
grid_size_
+=
grid_size_grp
;
grid_size_
+=
grid_size_grp
;
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_m_k
,
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_m_k
,
...
@@ -754,6 +748,14 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -754,6 +748,14 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
)
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
)
<<
"}"
;
<<
"}"
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
j
)
{
std
::
cout
<<
", arg.d"
<<
i
<<
"_grid_desc_m_n_{"
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
ds_grid_desc_m_n_
[
j
].
GetLength
(
I0
)
<<
", "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
ds_grid_desc_m_n_
[
j
].
GetLength
(
I1
)
<<
"}"
;
});
std
::
cout
<<
", arg.e_grid_desc_m_n_{ "
std
::
cout
<<
", arg.e_grid_desc_m_n_{ "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
e_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
e_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
e_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
e_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
...
@@ -805,6 +807,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -805,6 +807,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
BLayout
,
BLayout
,
DsLayout
,
DsLayout
,
ELayout
,
ELayout
,
DsDataType
,
Block2ETileMap
,
Block2ETileMap
,
GroupedGemmBlock2ETileMap
,
GroupedGemmBlock2ETileMap
,
AElementwiseOperation
,
AElementwiseOperation
,
...
...
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
View file @
5d61cd96
...
@@ -10,72 +10,6 @@ namespace ck {
...
@@ -10,72 +10,6 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
element_wise
{
namespace
element_wise
{
struct
Add
{
template
<
typename
Y
,
typename
X0
,
typename
X1
>
__host__
__device__
constexpr
void
operator
()(
Y
&
y
,
const
X0
&
x0
,
const
X1
&
x1
)
const
;
template
<
>
__host__
__device__
constexpr
void
operator
()
<
float
>
(
float
&
y
,
const
float
&
x0
,
const
float
&
x1
)
const
{
y
=
x0
+
x1
;
};
template
<
>
__host__
__device__
constexpr
void
operator
()
<
double
>
(
double
&
y
,
const
double
&
x0
,
const
double
&
x1
)
const
{
y
=
x0
+
x1
;
};
template
<
>
__host__
__device__
constexpr
void
operator
()
<
float
>
(
float
&
y
,
const
float
&
x0
,
const
half_t
&
x1
)
const
{
y
=
x0
+
type_convert
<
half_t
>
(
x1
);
};
template
<
>
__host__
__device__
constexpr
void
operator
()
<
half_t
>
(
half_t
&
y
,
const
float
&
x0
,
const
half_t
&
x1
)
const
{
y
=
type_convert
<
half_t
>
(
x0
)
+
x1
;
};
template
<
>
__host__
__device__
constexpr
void
operator
()
<
half_t
>
(
half_t
&
y
,
const
half_t
&
x0
,
const
half_t
&
x1
)
const
{
y
=
x0
+
x1
;
};
template
<
>
__host__
__device__
constexpr
void
operator
()
<
float
>
(
float
&
y
,
const
float
&
x0
,
const
bhalf_t
&
x1
)
const
{
const
float
x1_tmp
=
ck
::
type_convert
<
float
>
(
x1
);
y
=
x0
+
x1_tmp
;
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
bhalf_t
>
(
bhalf_t
&
y
,
const
bhalf_t
&
x0
,
const
bhalf_t
&
x1
)
const
{
const
float
x1_tmp
=
ck
::
type_convert
<
float
>
(
x0
);
const
float
x2_tmp
=
ck
::
type_convert
<
float
>
(
x1
);
const
float
y_tmp
=
x1_tmp
+
x2_tmp
;
y
=
ck
::
type_convert
<
bhalf_t
>
(
y_tmp
);
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
int8_t
>
(
int8_t
&
y
,
const
int8_t
&
x0
,
const
int8_t
&
x1
)
const
{
y
=
x0
+
x1
;
};
};
struct
ScaleAdd
struct
ScaleAdd
{
{
__host__
__device__
ScaleAdd
(
float
scale
)
:
scale_
(
scale
)
{}
__host__
__device__
ScaleAdd
(
float
scale
)
:
scale_
(
scale
)
{}
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
5d61cd96
...
@@ -51,6 +51,12 @@ struct PassThrough
...
@@ -51,6 +51,12 @@ struct PassThrough
y
=
x
;
y
=
x
;
}
}
template
<
>
__host__
__device__
void
operator
()
<
half_t
,
float
>
(
half_t
&
y
,
const
float
&
x
)
const
{
y
=
type_convert
<
half_t
>
(
x
);
}
template
<
>
template
<
>
__host__
__device__
void
operator
()
<
bhalf_t
,
float
>
(
bhalf_t
&
y
,
const
float
&
x
)
const
__host__
__device__
void
operator
()
<
bhalf_t
,
float
>
(
bhalf_t
&
y
,
const
float
&
x
)
const
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment