Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
63bad606
Commit
63bad606
authored
Jan 21, 2021
by
Jing Zhang
Browse files
demo of removing array for A/B in xdlops
parent
494608ce
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
41 additions
and
17 deletions
+41
-17
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+38
-14
composable_kernel/include/utility/amd_xdlops.hpp
composable_kernel/include/utility/amd_xdlops.hpp
+3
-3
No files found.
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @
63bad606
...
...
@@ -123,12 +123,9 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
FloatC
run
(
const
FloatA
*
a
,
const
FloatB
*
b
,
FloatC
reg_c
)
const
__device__
FloatC
run
(
const
FloatA
a
,
const
FloatB
b
,
FloatC
reg_c
)
const
{
const
auto
p_a
=
reinterpret_cast
<
const
float
*>
(
a
);
const
auto
p_b
=
reinterpret_cast
<
const
float
*>
(
b
);
return
intrin_mfma_f32_16x16x4f32
(
p_a
,
p_b
,
reg_c
);
return
intrin_mfma_f32_16x16x4f32
(
a
,
b
,
reg_c
);
}
};
...
...
@@ -708,6 +705,12 @@ struct XdlopsGemm_t
}
#endif
template
<
class
FloatAB
>
__device__
static
auto
lds_load
(
const
FloatAB
*
p_src
,
const
index_t
src_offset
)
{
return
p_src
[
src_offset
];
}
template
<
index_t
M
,
index_t
N
,
index_t
K
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
FloatC
Run
(
const
FloatA
*
const
__restrict__
p_a_wave
,
const
FloatB
*
const
__restrict__
p_b_wave
,
...
...
@@ -727,6 +730,13 @@ struct XdlopsGemm_t
FloatA
a
[
K
*
MRepeats
];
FloatB
b
[
K
*
NRepeats
];
constexpr
index_t
data_size
=
sizeof
(
FloatA
)
/
sizeof
(
data_type
);
constexpr
index_t
a_reg_buff_size
=
K
*
MRepeats
*
data_size
;
constexpr
index_t
b_reg_buff_size
=
K
*
NRepeats
*
data_size
;
auto
reg_a
=
GetRegBuffer
<
data_type
,
a_reg_buff_size
>
();
auto
reg_b
=
GetRegBuffer
<
data_type
,
b_reg_buff_size
>
();
static_assert
(
sizeof
(
FloatA
)
%
(
sizeof
(
data_type
)
*
mfma_type
.
k_base
)
==
0
,
"wrong! FloatA is consistent with mfma"
);
...
...
@@ -769,25 +779,39 @@ struct XdlopsGemm_t
const
index_t
blk_id
=
laneId
/
mfma_type
.
num_threads_blk
;
const
index_t
blk_td
=
laneId
%
mfma_type
.
num_threads_blk
;
#if 0
// load into registers
for(index_t k_i = 0; k_i < K; k_i += mfma_type.num_input_blks)
{
a[k_i] = p_a_wave[(k_i + blk_id) * M + blk_td];
b[k_i] = p_b_wave[(k_i + blk_id) * N + blk_td];
}
#if CK_WORKAROUND_SWDEV_229564
#pragma unroll
#endif
for
(
index_t
k_i
=
0
;
k_i
<
K
;
k_i
+=
mfma_type
.
num_input_blks
)
{
for
(
index_t
i
=
0
;
i
<
KRepeats
;
++
i
)
static_for
<
0
,
K
,
mfma_type
.
num_input_blks
>
{}([
&
](
auto
k_i
)
{
index_t
a_offset
=
(
k_i
+
blk_id
)
*
M
+
blk_td
;
reg_a
.
GetVector
(
Number
<
data_size
>
{})(
Number
<
k_i
>
{})
=
lds_load
(
p_a_wave
,
a_offset
);
index_t
b_offset
=
(
k_i
+
blk_id
)
*
N
+
blk_td
;
reg_b
.
GetVector
(
Number
<
data_size
>
{})(
Number
<
k_i
>
{})
=
lds_load
(
p_b_wave
,
b_offset
);
});
// for(index_t k_i = 0; k_i < K; k_i += mfma_type.num_input_blks)
// for(index_t i = 0; i < KRepeats; ++i)
static_for
<
0
,
K
,
mfma_type
.
num_input_blks
>
{}([
&
](
auto
k_i
)
{
static_for
<
0
,
KRepeats
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
offset
=
k_i
*
KRepeats
+
i
;
p_c_thread
=
mfma_type
.
template
run
<
MPerXdlops
,
NPerXdlops
,
AStride
,
BStride
>(
&
pa
[(
k_i
*
KRepeats
+
i
)
*
mfma_type
.
k_base
],
&
pb
[(
k_i
*
KRepeats
+
i
)
*
mfma_type
.
k_base
],
reg_a
.
GetVector
(
Number
<
mfma_type
.
k_base
>
{})[
Number
<
offset
>
{}
],
reg_b
.
GetVector
(
Number
<
mfma_type
.
k_base
>
{})[
Number
<
offset
>
{}
],
p_c_thread
);
}
});
});
});
#endif
...
...
composable_kernel/include/utility/amd_xdlops.hpp
View file @
63bad606
...
...
@@ -132,12 +132,12 @@ intrin_mfma_f32_32x32x2f32(const float* reg_a, const float* reg_b, c_vec16_1_t::
return
reg_c
;
}
__device__
float_vec4_t
intrin_mfma_f32_16x16x4f32
(
const
float
*
reg_a
,
const
float
*
reg_b
,
__device__
float_vec4_t
intrin_mfma_f32_16x16x4f32
(
const
float
reg_a
,
const
float
reg_b
,
float_vec4_t
reg_c
)
{
reg_c
.
s4
(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_16x16x4f32
(
reg_a
[
0
]
,
reg_b
[
0
]
,
reg_c
.
s4
[
Number
<
0
>
{}],
0
,
0
,
0
);
llvm_intrin_amdgcn_mfma_f32_16x16x4f32
(
reg_a
,
reg_b
,
reg_c
.
s4
[
Number
<
0
>
{}],
0
,
0
,
0
);
return
reg_c
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment