Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
05ab9105
Commit
05ab9105
authored
Oct 19, 2024
by
Jing Zhang
Browse files
fixed reference and host_tensor
parent
205e0365
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
92 additions
and
27 deletions
+92
-27
example/01_gemm/run_gemm_example_v2.inc
example/01_gemm/run_gemm_example_v2.inc
+33
-19
include/ck/utility/amd_xdlops.hpp
include/ck/utility/amd_xdlops.hpp
+2
-2
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
...library/reference_tensor_operation/cpu/reference_gemm.hpp
+11
-0
library/include/ck/library/utility/host_tensor.hpp
library/include/ck/library/utility/host_tensor.hpp
+46
-6
No files found.
example/01_gemm/run_gemm_example_v2.inc
View file @
05ab9105
...
@@ -158,6 +158,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -158,6 +158,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
}
}
b_k_n
(
0
,
0
)
=
0xaa
;
b_k_n
(
1
,
1
)
=
0xaa
;
Tensor
<
CDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
...
@@ -207,31 +210,42 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -207,31 +210,42 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
bool
pass
=
true
;
bool
pass
=
true
;
if
(
config
.
do_verification
)
if
(
config
.
do_verification
)
{
{
//
auto ref_gemm = ReferenceGemmInstance{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
//
auto ref_invoker = ref_gemm.MakeInvoker();
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
//
auto ref_argument = ref_gemm.MakeArgument(
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
//
a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
a_m_k
,
b_k_n
,
c_m_n_host_result
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
//
ref_invoker.Run(ref_argument);
ref_invoker
.
Run
(
ref_argument
);
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
,
1
});
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
,
1
});
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
//pass &= ck::utils::check_err(c_m_n_device_result,
pass
&=
ck
::
utils
::
check_err
(
c_m_n_device_result
,
// c_m_n_host_result,
c_m_n_host_result
,
// "Error: Incorrect results!",
"Error: Incorrect results!"
,
// get_rtol<CDataType>(),
get_rtol
<
CDataType
>
(),
// get_atol<CDataType>());
get_atol
<
CDataType
>
());
//for(int i = 0; i < M; i++)
std
::
cout
<<
"c_m_n_device_result: "
<<
std
::
endl
;
//{
for
(
int
i
=
0
;
i
<
M
;
i
++
)
// for(int j = 0; j < N; j++)
{
// {
for
(
int
j
=
0
;
j
<
N
;
j
++
)
// std::cout << ck::type_convert<float>(c_m_n_device_result(i, j)) << ",";
{
// }
std
::
cout
<<
ck
::
type_convert
<
float
>
(
c_m_n_device_result
(
i
,
j
))
<<
","
;
// std::cout << std::endl;
}
//}
std
::
cout
<<
std
::
endl
;
}
std
::
cout
<<
"c_m_n_host_result: "
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
M
;
i
++
)
{
for
(
int
j
=
0
;
j
<
N
;
j
++
)
{
std
::
cout
<<
ck
::
type_convert
<
float
>
(
c_m_n_host_result
(
i
,
j
))
<<
","
;
}
std
::
cout
<<
std
::
endl
;
}
}
}
if
(
config
.
time_kernel
)
if
(
config
.
time_kernel
)
...
...
include/ck/utility/amd_xdlops.hpp
View file @
05ab9105
...
@@ -157,8 +157,8 @@ struct intrin_mfma_f32_16x16x16f16<16, 16>
...
@@ -157,8 +157,8 @@ struct intrin_mfma_f32_16x16x16f16<16, 16>
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
//
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_mfma_f32_16x16x16f16
(
//
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
}
}
};
};
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
View file @
05ab9105
...
@@ -84,6 +84,17 @@ struct ReferenceGemm : public device::BaseOperator
...
@@ -84,6 +84,17 @@ struct ReferenceGemm : public device::BaseOperator
{
{
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}(
v_b
,
arg
.
b_k_n_
(
k
,
n
));
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}(
v_b
,
arg
.
b_k_n_
(
k
,
n
));
}
}
else
if
constexpr
(
is_same_v
<
BDataType
,
pk_i4_t
>
)
{
pk_i4_t
i4x2
=
arg
.
b_k_n_
(
k
,
n
);
int8_t
i4
=
0
;
if
(
k
%
2
==
1
)
i4
=
(
i4x2
>>
0
)
&
0xf
;
else
i4
=
(
i4x2
>>
4
)
&
0xf
;
i4
=
i4
-
8
;
arg
.
b_element_op_
(
v_b
,
i4
);
}
else
else
{
{
arg
.
b_element_op_
(
v_b
,
arg
.
b_k_n_
(
k
,
n
));
arg
.
b_element_op_
(
v_b
,
arg
.
b_k_n_
(
k
,
n
));
...
...
library/include/ck/library/utility/host_tensor.hpp
View file @
05ab9105
...
@@ -322,7 +322,12 @@ struct Tensor
...
@@ -322,7 +322,12 @@ struct Tensor
std
::
size_t
GetElementSize
()
const
{
return
mDesc
.
GetElementSize
();
}
std
::
size_t
GetElementSize
()
const
{
return
mDesc
.
GetElementSize
();
}
std
::
size_t
GetElementSpaceSize
()
const
{
return
mDesc
.
GetElementSpaceSize
();
}
std
::
size_t
GetElementSpaceSize
()
const
{
if
constexpr
(
ck
::
is_same_v
<
T
,
ck
::
pk_i4_t
>
)
return
mDesc
.
GetElementSpaceSize
()
/
2
;
else
return
mDesc
.
GetElementSpaceSize
();
}
std
::
size_t
GetElementSpaceSizeInBytes
()
const
{
return
sizeof
(
T
)
*
GetElementSpaceSize
();
}
std
::
size_t
GetElementSpaceSizeInBytes
()
const
{
return
sizeof
(
T
)
*
GetElementSpaceSize
();
}
...
@@ -469,29 +474,64 @@ struct Tensor
...
@@ -469,29 +474,64 @@ struct Tensor
template
<
typename
...
Is
>
template
<
typename
...
Is
>
std
::
size_t
GetOffsetFromMultiIndex
(
Is
...
is
)
const
std
::
size_t
GetOffsetFromMultiIndex
(
Is
...
is
)
const
{
{
return
mDesc
.
GetOffsetFromMultiIndex
(
is
...);
if
constexpr
(
ck
::
is_same_v
<
T
,
ck
::
pk_i4_t
>
)
{
return
mDesc
.
GetOffsetFromMultiIndex
(
is
...)
/
2
;
}
else
{
return
mDesc
.
GetOffsetFromMultiIndex
(
is
...);
}
}
}
template
<
typename
...
Is
>
template
<
typename
...
Is
>
T
&
operator
()(
Is
...
is
)
T
&
operator
()(
Is
...
is
)
{
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
is
...)];
if
constexpr
(
ck
::
is_same_v
<
T
,
ck
::
pk_i4_t
>
)
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
is
...)
/
2
];
}
else
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
is
...)];
}
}
}
template
<
typename
...
Is
>
template
<
typename
...
Is
>
const
T
&
operator
()(
Is
...
is
)
const
const
T
&
operator
()(
Is
...
is
)
const
{
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
is
...)];
if
constexpr
(
ck
::
is_same_v
<
T
,
ck
::
pk_i4_t
>
)
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
is
...)
/
2
];
}
else
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
is
...)];
}
}
}
T
&
operator
()(
std
::
vector
<
std
::
size_t
>
idx
)
T
&
operator
()(
std
::
vector
<
std
::
size_t
>
idx
)
{
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
idx
)];
if
constexpr
(
ck
::
is_same_v
<
T
,
ck
::
pk_i4_t
>
)
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
idx
)
/
2
];
}
else
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
idx
)];
}
}
}
const
T
&
operator
()(
std
::
vector
<
std
::
size_t
>
idx
)
const
const
T
&
operator
()(
std
::
vector
<
std
::
size_t
>
idx
)
const
{
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
idx
)];
if
constexpr
(
ck
::
is_same_v
<
T
,
ck
::
pk_i4_t
>
)
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
idx
)
/
2
];
}
else
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
idx
)];
}
}
}
typename
Data
::
iterator
begin
()
{
return
mData
.
begin
();
}
typename
Data
::
iterator
begin
()
{
return
mData
.
begin
();
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment