Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
6aaa77c1
"...composable_kernel.git" did not exist on "f63ca8e88150b428cb9f3be15ae472e8e38dd303"
Commit
6aaa77c1
authored
Oct 02, 2023
by
Jing Zhang
Browse files
finished an example
parent
7e734a03
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
131 additions
and
113 deletions
+131
-113
example/01_gemm/gemm_xdl_input_i16_comp_i8_scale_ab.cpp
example/01_gemm/gemm_xdl_input_i16_comp_i8_scale_ab.cpp
+15
-8
example/12_reduce/reduce_blockwise_two_call.cpp
example/12_reduce/reduce_blockwise_two_call.cpp
+3
-2
example/12_reduce/reduce_blockwise_two_call_amax.cpp
example/12_reduce/reduce_blockwise_two_call_amax.cpp
+12
-33
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_scale_ab_xdl_cshuffle.hpp
...ice/impl/device_gemm_multiple_d_scale_ab_xdl_cshuffle.hpp
+101
-70
No files found.
example/01_gemm/gemm_xdl_input_i16_comp_i8_scale_ab.cpp
View file @
6aaa77c1
...
@@ -44,10 +44,10 @@ struct i32_to_i8
...
@@ -44,10 +44,10 @@ struct i32_to_i8
{
{
__host__
__device__
void
operator
()(
I8
&
y
,
const
I32
&
x
)
const
__host__
__device__
void
operator
()(
I8
&
y
,
const
I32
&
x
)
const
{
{
y
=
ck
::
type_convert
<
I8
>
(
x
)
*
scale
;
y
=
ck
::
type_convert
<
I8
>
(
ck
::
type_convert
<
float
>
(
x
)
*
reduced_amex_
scale
)
;
}
}
float
scale
=
1.0
;
float
reduced_amex_
scale
=
1.0
;
};
};
using
AElementOp
=
i32_to_i8
;
using
AElementOp
=
i32_to_i8
;
...
@@ -175,12 +175,15 @@ int main(int argc, char* argv[])
...
@@ -175,12 +175,15 @@ int main(int argc, char* argv[])
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"e_m_n: "
<<
e_m_n_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"e_m_n: "
<<
e_m_n_host_result
.
mDesc
<<
std
::
endl
;
ADataType
amax
=
5
;
BDataType
bmax
=
8
;
switch
(
init_method
)
switch
(
init_method
)
{
{
case
0
:
break
;
case
0
:
break
;
case
1
:
case
1
:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
amax
,
amax
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
bmax
,
bmax
});
break
;
break
;
default:
default:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
...
@@ -195,8 +198,8 @@ int main(int argc, char* argv[])
...
@@ -195,8 +198,8 @@ int main(int argc, char* argv[])
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
e_device_buf
.
ToDevice
(
e_m_n_device_result
.
mData
.
data
());
e_device_buf
.
ToDevice
(
e_m_n_device_result
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{
0.2
};
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{
0.2
};
auto
b_element_op
=
BElementOp
{};
auto
cde_element_op
=
CDEElementOp
{};
auto
cde_element_op
=
CDEElementOp
{};
// do GEMM
// do GEMM
...
@@ -254,8 +257,12 @@ int main(int argc, char* argv[])
...
@@ -254,8 +257,12 @@ int main(int argc, char* argv[])
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_m_k
,
ref_gemm
.
MakeArgument
(
a_m_k
,
b_k_n
,
c_m_n
,
a_element_op
,
b_element_op
,
PassThrough
{});
b_k_n
,
c_m_n
,
AElementOp
{
static_cast
<
float
>
(
1.0
)
/
amax
},
BElementOp
{
static_cast
<
float
>
(
1.0
)
/
bmax
},
PassThrough
{});
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
...
...
example/12_reduce/reduce_blockwise_two_call.cpp
View file @
6aaa77c1
...
@@ -265,8 +265,9 @@ int main(int argc, char* argv[])
...
@@ -265,8 +265,9 @@ int main(int argc, char* argv[])
if
(
!
reduce_1
.
IsSupportedArgument
(
argument_ptr_1
.
get
()))
if
(
!
reduce_1
.
IsSupportedArgument
(
argument_ptr_1
.
get
()))
{
{
std
::
cout
<<
"The runtime parameters seems supported by the DeviceReduce instance, exiting!"
std
::
cout
<<
std
::
endl
;
<<
"The runtime parameters seems not supported by the DeviceReduce instance, exiting!"
<<
std
::
endl
;
};
};
auto
invoker_ptr_1
=
reduce_1
.
MakeInvokerPointer
();
auto
invoker_ptr_1
=
reduce_1
.
MakeInvokerPointer
();
...
...
example/12_reduce/reduce_blockwise_two_call_amax.cpp
View file @
6aaa77c1
...
@@ -73,8 +73,8 @@ using DeviceReduceInstance_2 = DeviceReduceMultiBlock<InOutDataType,
...
@@ -73,8 +73,8 @@ using DeviceReduceInstance_2 = DeviceReduceMultiBlock<InOutDataType,
OutputIndex
,
OutputIndex
,
false
,
// HaveIndexInputIfOutputIndex
false
,
// HaveIndexInputIfOutputIndex
256
,
256
,
128
,
32
,
2
,
8
,
1
,
1
,
1
,
1
,
1
,
// vector dim
1
,
// vector dim
...
@@ -83,14 +83,12 @@ using DeviceReduceInstance_2 = DeviceReduceMultiBlock<InOutDataType,
...
@@ -83,14 +83,12 @@ using DeviceReduceInstance_2 = DeviceReduceMultiBlock<InOutDataType,
static
bool
do_verify
;
static
bool
do_verify
;
static
int
init_method
;
static
int
init_method
;
static
float
alpha
;
static
float
beta
;
static
bool
time_kernel
;
static
bool
time_kernel
;
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
// used by the device reduction
// used by the device reduction
const
std
::
array
<
int
,
1
>
reduceDims_1
=
{
0
};
const
std
::
array
<
int
,
1
>
reduceDims_1
=
{
1
};
const
std
::
array
<
int
,
1
>
reduceDims_2
=
{
0
};
const
std
::
array
<
int
,
1
>
reduceDims_2
=
{
0
};
// used by the host reduction
// used by the host reduction
...
@@ -126,9 +124,6 @@ int main(int argc, char* argv[])
...
@@ -126,9 +124,6 @@ int main(int argc, char* argv[])
throw
std
::
runtime_error
(
ostr
.
str
());
throw
std
::
runtime_error
(
ostr
.
str
());
};
};
alpha
=
1.0
f
;
beta
=
0.0
f
;
Tensor
<
InOutDataType
>
in_1
(
inLengths_1
);
Tensor
<
InOutDataType
>
in_1
(
inLengths_1
);
Tensor
<
InOutDataType
>
out_ref
(
outLengths
);
Tensor
<
InOutDataType
>
out_ref
(
outLengths
);
...
@@ -149,26 +144,12 @@ int main(int argc, char* argv[])
...
@@ -149,26 +144,12 @@ int main(int argc, char* argv[])
switch
(
init_method
)
switch
(
init_method
)
{
{
case
0
:
break
;
case
0
:
break
;
case
1
:
case
1
:
in_1
.
GenerateTensorValue
(
GeneratorTensor_1
<
InOutDataType
>
{
1
},
num_thread
);
break
;
in_1
.
GenerateTensorValue
(
GeneratorTensor_1
<
InOutDataType
>
{
1
},
num_thread
);
if
(
beta
!=
0.0
f
)
out_ref
.
GenerateTensorValue
(
GeneratorTensor_1
<
InOutDataType
>
{
1
},
num_thread
);
break
;
case
2
:
case
2
:
in_1
.
GenerateTensorValue
(
GeneratorTensor_2
<
InOutDataType
>
{
-
5
,
5
},
num_thread
);
in_1
.
GenerateTensorValue
(
GeneratorTensor_2
<
InOutDataType
>
{
-
5
,
5
},
num_thread
);
if
(
beta
!=
0.0
f
)
out_ref
.
GenerateTensorValue
(
GeneratorTensor_2
<
InOutDataType
>
{
-
5
,
5
},
num_thread
);
break
;
break
;
default:
default:
in_1
.
GenerateTensorValue
(
GeneratorTensor_3
<
InOutDataType
>
{
-
5.0
,
5.0
},
num_thread
);
in_1
.
GenerateTensorValue
(
GeneratorTensor_3
<
InOutDataType
>
{
-
5.0
,
5.0
},
num_thread
);
if
(
beta
!=
0.0
f
)
out_ref
.
GenerateTensorValue
(
GeneratorTensor_3
<
InOutDataType
>
{
-
5.0
,
5.0
},
num_thread
);
}
}
if
(
beta
!=
0.0
f
)
for
(
size_t
i
=
0
;
i
<
out_ref
.
mDesc
.
GetElementSpaceSize
();
i
++
)
out
.
mData
[
i
]
=
out_ref
.
mData
[
i
];
};
};
DeviceMem
in_1_dev
(
sizeof
(
InOutDataType
)
*
in_1
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
in_1_dev
(
sizeof
(
InOutDataType
)
*
in_1
.
mDesc
.
GetElementSpaceSize
());
...
@@ -177,9 +158,6 @@ int main(int argc, char* argv[])
...
@@ -177,9 +158,6 @@ int main(int argc, char* argv[])
in_1_dev
.
ToDevice
(
in_1
.
mData
.
data
());
in_1_dev
.
ToDevice
(
in_1
.
mData
.
data
());
if
(
beta
!=
0.0
f
)
out_dev
.
ToDevice
(
out
.
mData
.
data
());
InElementwiseOperation
in_elementwise_op
;
InElementwiseOperation
in_elementwise_op
;
AccElementwiseOperation
acc_elementwise_op
;
AccElementwiseOperation
acc_elementwise_op
;
...
@@ -222,8 +200,8 @@ int main(int argc, char* argv[])
...
@@ -222,8 +200,8 @@ int main(int argc, char* argv[])
arrOutLengths
,
arrOutLengths
,
arrOutStrides
,
arrOutStrides
,
reduceDims
,
reduceDims
,
static_cast
<
double
>
(
alpha
)
,
1.0
,
static_cast
<
double
>
(
beta
)
,
0.0
,
in_1
.
mData
.
data
(),
in_1
.
mData
.
data
(),
nullptr
,
nullptr
,
out_ref
.
mData
.
data
(),
out_ref
.
mData
.
data
(),
...
@@ -261,8 +239,9 @@ int main(int argc, char* argv[])
...
@@ -261,8 +239,9 @@ int main(int argc, char* argv[])
if
(
!
reduce_1
.
IsSupportedArgument
(
argument_ptr_1
.
get
()))
if
(
!
reduce_1
.
IsSupportedArgument
(
argument_ptr_1
.
get
()))
{
{
std
::
cout
<<
"The runtime parameters seems supported by the DeviceReduce instance, exiting!"
std
::
cout
<<
std
::
endl
;
<<
"The runtime parameters seems not supported by the DeviceReduce instance, exiting!"
<<
std
::
endl
;
};
};
auto
invoker_ptr_1
=
reduce_1
.
MakeInvokerPointer
();
auto
invoker_ptr_1
=
reduce_1
.
MakeInvokerPointer
();
...
@@ -274,8 +253,8 @@ int main(int argc, char* argv[])
...
@@ -274,8 +253,8 @@ int main(int argc, char* argv[])
arrOutLengths
,
arrOutLengths
,
arrOutStrides
,
arrOutStrides
,
reduceDims_2
,
reduceDims_2
,
static_cast
<
double
>
(
alpha
)
,
1.0
,
static_cast
<
double
>
(
beta
)
,
0.0
,
in_2_dev
.
GetDeviceBuffer
(),
in_2_dev
.
GetDeviceBuffer
(),
nullptr
,
nullptr
,
out_dev
.
GetDeviceBuffer
(),
out_dev
.
GetDeviceBuffer
(),
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_scale_ab_xdl_cshuffle.hpp
View file @
6aaa77c1
...
@@ -20,6 +20,8 @@
...
@@ -20,6 +20,8 @@
#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/utility/is_detected.hpp"
namespace
ck
{
namespace
ck
{
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
...
@@ -164,6 +166,8 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
...
@@ -164,6 +166,8 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
{
{
using
DeviceOp
=
DeviceGemmMultipleDScaleAB_Xdl_CShuffle
;
using
DeviceOp
=
DeviceGemmMultipleDScaleAB_Xdl_CShuffle
;
using
RowMajor
=
tensor_layout
::
gemm
::
RowMajor
;
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -177,7 +181,7 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
...
@@ -177,7 +181,7 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
static
auto
MakeAGridDescriptor_M_K
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
static
auto
MakeAGridDescriptor_M_K
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
{
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
if
constexpr
(
is_same_v
<
RowMajor
,
ALayout
>
)
{
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
make_tuple
(
StrideA
,
I1
));
make_tuple
(
StrideA
,
I1
));
...
@@ -195,7 +199,7 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
...
@@ -195,7 +199,7 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
static
auto
MakeBGridDescriptor_N_K
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
static
auto
MakeBGridDescriptor_N_K
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
{
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
if
constexpr
(
is_same
<
RowMajor
,
BLayout
>::
value
)
{
{
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
I1
,
StrideB
));
make_tuple
(
I1
,
StrideB
));
...
@@ -214,7 +218,7 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
...
@@ -214,7 +218,7 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
static
auto
MakeEGridDescriptor_M_N
(
index_t
MRaw
,
index_t
NRaw
,
index_t
StrideE
)
static
auto
MakeEGridDescriptor_M_N
(
index_t
MRaw
,
index_t
NRaw
,
index_t
StrideE
)
{
{
const
auto
e_grid_desc_mraw_nraw
=
[
&
]()
{
const
auto
e_grid_desc_mraw_nraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ELay
>::
value
)
if
constexpr
(
is_same
<
RowMajor
,
ELay
>::
value
)
{
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
NRaw
),
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
NRaw
),
make_tuple
(
StrideE
,
I1
));
make_tuple
(
StrideE
,
I1
));
...
@@ -425,7 +429,7 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
...
@@ -425,7 +429,7 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
index_t
KRaw_
;
index_t
KRaw_
;
};
};
template
<
typename
InOutDataType
>
template
<
typename
InOutDataType
,
typename
Layout
>
struct
Reduce2D
struct
Reduce2D
{
{
static
constexpr
ReduceTensorOp
ReduceOpId
=
ReduceTensorOp
::
AMAX
;
static
constexpr
ReduceTensorOp
ReduceOpId
=
ReduceTensorOp
::
AMAX
;
...
@@ -440,26 +444,27 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
...
@@ -440,26 +444,27 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
using
DeviceReduceInstance_1
=
DeviceReduceMultiBlock
<
InOutDataType
,
using
DeviceReduceInstance_1
=
InOutDataType
,
DeviceReduceMultiBlock
<
InOutDataType
,
InOutDataType
,
InOutDataType
,
2
,
// Rank
InOutDataType
,
1
,
// NumReduceDim
2
,
// Rank
ReduceOperation
,
1
,
// NumReduceDim
InElementwiseOperation
,
ReduceOperation
,
PassThroughOp
,
InElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
PassThroughOp
,
PropagateNan
,
InMemoryDataOperationEnum
::
Set
,
OutputIndex
,
PropagateNan
,
false
,
// HaveIndexInputIfOutputIndex
OutputIndex
,
256
,
false
,
// HaveIndexInputIfOutputIndex
32
,
256
,
8
,
32
,
1
,
8
,
1
,
1
,
1
,
// vector dim
1
,
1
,
is_same
<
RowMajor
,
Layout
>::
value
?
1
:
0
,
// vector dim
1
>
;
1
,
1
>
;
using
DeviceReduceInstance_2
=
DeviceReduceMultiBlock
<
InOutDataType
,
using
DeviceReduceInstance_2
=
DeviceReduceMultiBlock
<
InOutDataType
,
InOutDataType
,
InOutDataType
,
...
@@ -473,9 +478,9 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
...
@@ -473,9 +478,9 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
PropagateNan
,
PropagateNan
,
OutputIndex
,
OutputIndex
,
false
,
// HaveIndexInputIfOutputIndex
false
,
// HaveIndexInputIfOutputIndex
256
,
256
,
// BlockSize
128
,
32
,
// MThreadClusterSize
2
,
8
,
// KThreadClusterSize
1
,
1
,
1
,
1
,
1
,
// vector dim
1
,
// vector dim
...
@@ -493,7 +498,7 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
...
@@ -493,7 +498,7 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
const
std
::
array
<
int
,
1
>
reduceDims_1
=
{
arrInLengths_1
[
0
]
>
arrInLengths_1
[
1
]
?
0
:
1
};
const
std
::
array
<
int
,
1
>
reduceDims_1
=
{
arrInLengths_1
[
0
]
>
arrInLengths_1
[
1
]
?
0
:
1
};
const
std
::
array
<
int
,
1
>
reduceDims_2
=
{
0
};
const
std
::
array
<
int
,
1
>
reduceDims_2
=
{
0
};
std
::
array
<
index_t
,
1
>
arrInLengths_2
{
arrInLengths_1
[
reduceDims_1
[
0
]]};
std
::
array
<
index_t
,
1
>
arrInLengths_2
{
arrInLengths_1
[
!
reduceDims_1
[
0
]]};
std
::
array
<
index_t
,
1
>
arrInStrides_2
{
1
};
std
::
array
<
index_t
,
1
>
arrInStrides_2
{
1
};
std
::
array
<
index_t
,
1
>
arrOutLengths
{
1
};
std
::
array
<
index_t
,
1
>
arrOutLengths
{
1
};
...
@@ -520,9 +525,10 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
...
@@ -520,9 +525,10 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
if
(
!
reduce_1
.
IsSupportedArgument
(
argument_ptr_1
.
get
()))
if
(
!
reduce_1
.
IsSupportedArgument
(
argument_ptr_1
.
get
()))
{
{
std
::
cout
<<
"The runtime parameters seems supported by the DeviceReduce instance, "
std
::
cout
"exiting!"
<<
"The runtime parameters seems not supported by the DeviceReduce instance, "
<<
std
::
endl
;
"exiting!"
<<
std
::
endl
;
};
};
auto
invoker_ptr_1
=
reduce_1
.
MakeInvokerPointer
();
auto
invoker_ptr_1
=
reduce_1
.
MakeInvokerPointer
();
...
@@ -564,6 +570,9 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
...
@@ -564,6 +570,9 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
{
{
using
Argument
=
DeviceOp
::
Argument
;
using
Argument
=
DeviceOp
::
Argument
;
template
<
typename
T
>
using
has_reduced_amex_scale
=
decltype
(
std
::
declval
<
T
&>
().
reduced_amex_scale
);
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
...
@@ -575,45 +584,67 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
...
@@ -575,45 +584,67 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
}
using
RowMajor
=
tensor_layout
::
gemm
::
RowMajor
;
float
kern_time
=
0
;
float
kern_time
=
0
;
ADataType
amax_a
,
amax_b
;
AElementwiseOperation
a_element_op_
=
arg
.
a_element_op_
;
auto
reduce_a
=
Reduce2D
<
ADataType
>
{};
if
constexpr
(
is_detected
<
has_reduced_amex_scale
,
AElementwiseOperation
>::
value
)
kern_time
+=
reduce_a
.
Run
({
arg
.
MRaw_
,
arg
.
KRaw_
},
{
is_same
<
RowMajor
,
ALayout
>::
value
// A[M, K]
ADataType
amax_a
;
?
std
::
array
<
index_t
,
2
>
{
arg
.
KRaw_
,
I1
}
:
std
::
array
<
index_t
,
2
>
{
I1
,
arg
.
MRaw_
},
auto
reduce_a
=
Reduce2D
<
ADataType
,
ALayout
>
{};
arg
.
p_a_grid_
,
kern_time
+=
reduce_a
.
Run
({
arg
.
MRaw_
,
arg
.
KRaw_
},
arg
.
p_e_grid_
,
is_same
<
RowMajor
,
ALayout
>::
value
// A[M, K]
arg
.
p_e_grid_
,
?
std
::
array
<
index_t
,
2
>
{
arg
.
KRaw_
,
I1
}
stream_config
);
:
std
::
array
<
index_t
,
2
>
{
I1
,
arg
.
MRaw_
},
arg
.
p_a_grid_
,
hipGetErrorString
(
hipMemcpyWithStream
(
&
amax_a
,
arg
.
p_e_grid_
,
arg
.
p_e_grid_
,
arg
.
p_e_grid_
,
sizeof
(
ADataType
),
stream_config
);
hipMemcpyDeviceToHost
,
stream_config
.
stream_id_
));
hipGetErrorString
(
hipMemcpyWithStream
(
&
amax_a
,
arg
.
p_e_grid_
,
auto
reduce_b
=
Reduce2D
<
BDataType
>
{};
sizeof
(
ADataType
),
kern_time
+=
reduce_b
.
Run
({
arg
.
KRaw_
,
arg
.
NRaw_
},
hipMemcpyDeviceToHost
,
is_same
<
RowMajor
,
BLayout
>::
value
// B[K, N]
stream_config
.
stream_id_
));
?
std
::
array
<
index_t
,
2
>
{
arg
.
NRaw_
,
I1
}
:
std
::
array
<
index_t
,
2
>
{
I1
,
arg
.
KRaw_
},
static_assert
(
is_same
<
decltype
(
arg
.
a_element_op_
.
reduced_amex_scale
),
float
>::
value
,
arg
.
p_a_grid_
,
"scale is not float!"
);
arg
.
p_e_grid_
,
arg
.
p_e_grid_
,
a_element_op_
.
reduced_amex_scale
=
1.0
/
amax_a
;
stream_config
);
// std::cout << " amax_a: " << amax_a << std::endl;
hipGetErrorString
(
hipMemcpyWithStream
(
&
amax_b
,
}
arg
.
p_e_grid_
,
sizeof
(
ADataType
),
BElementwiseOperation
b_element_op_
=
arg
.
b_element_op_
;
hipMemcpyDeviceToHost
,
stream_config
.
stream_id_
));
if
constexpr
(
is_detected
<
has_reduced_amex_scale
,
BElementwiseOperation
>::
value
)
{
// std::cout << "amax_a: " << amax_a << " amax_b: " << amax_b << std::endl;
ADataType
amax_b
;
auto
reduce_b
=
Reduce2D
<
BDataType
,
BLayout
>
{};
kern_time
+=
reduce_b
.
Run
({
arg
.
KRaw_
,
arg
.
NRaw_
},
is_same
<
RowMajor
,
BLayout
>::
value
// B[K, N]
?
std
::
array
<
index_t
,
2
>
{
arg
.
NRaw_
,
I1
}
:
std
::
array
<
index_t
,
2
>
{
I1
,
arg
.
KRaw_
},
arg
.
p_b_grid_
,
arg
.
p_e_grid_
,
arg
.
p_e_grid_
,
stream_config
);
hipGetErrorString
(
hipMemcpyWithStream
(
&
amax_b
,
arg
.
p_e_grid_
,
sizeof
(
BDataType
),
hipMemcpyDeviceToHost
,
stream_config
.
stream_id_
));
static_assert
(
is_same
<
decltype
(
arg
.
b_element_op_
.
reduced_amex_scale
),
float
>::
value
,
"scale is not float!"
);
b_element_op_
.
reduced_amex_scale
=
1.0
/
amax_b
;
// std::cout << " amax_b: " << amax_b << std::endl;
}
const
index_t
grid_size
=
const
index_t
grid_size
=
arg
.
block_2_etile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
);
arg
.
block_2_etile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
);
...
@@ -646,8 +677,8 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
...
@@ -646,8 +677,8 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
p_e_grid_
,
arg
.
a_element_op_
,
a_element_op_
,
arg
.
b_element_op_
,
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
cde_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment