gaoqiong / composable_kernel_ROCM / Commits / d8dc850e
"test/vscode:/vscode.git/clone" did not exist on "d4e5a3ea93d1e0da847ff96420efbd40862164a9"
Commit d8dc850e authored Feb 29, 2024 by Adam Osewski

Use buffer loads and proper cache coherence.

Parent: b398481e
Showing 1 changed file with 61 additions and 16 deletions (+61 −16)
test/work_scheduling/test_strided_reduction_tile_loop.cpp  +61 −16  (view file @ d8dc850e)
@@ -48,8 +48,8 @@ struct GemmArgDesc
 template <index_t MPerBlock, index_t NPerBlock, index_t KPerBlock>
 __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p_gemm_descs,
-                                                            volatile float* p_workspace,
-                                                            volatile uint32_t* p_flags,
+                                                            float* p_workspace,
+                                                            uint32_t* p_flags,
                                                             index_t tile_count,
                                                             index_t k_batch)
 {
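Editor's note: volatile on p_workspace and p_flags only pinned accesses at the compiler level (no caching of the value in registers); it said nothing about which hardware cache services them. Once the accesses below are explicit buffer loads/stores carrying coherence flags, the qualifier becomes redundant. A minimal sketch of the guarantee volatile was providing (illustrative only, not code from this file):

    #include <hip/hip_runtime.h>
    #include <cstdint>

    // Spin until another workgroup publishes `expected`. volatile forces the
    // compiler to re-read p_flag from memory on every iteration; visibility
    // across CU caches still needs fences and/or coherence bits.
    __device__ void spin_on_flag(volatile uint32_t* p_flag, uint32_t expected)
    {
        while(*p_flag != expected) {}
        __threadfence(); // order subsequent data reads after the observed flag
    }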
@@ -88,9 +88,9 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
     const index_t N = p_gemm_descs[group_id].N;
     const index_t K = p_gemm_descs[group_id].K;
-    const auto p_A = p_gemm_descs[group_id].p_A;
-    const auto p_B = p_gemm_descs[group_id].p_B;
-    const auto p_C = p_gemm_descs[group_id].p_C;
+    auto p_A = const_cast<float*>(p_gemm_descs[group_id].p_A);
+    auto p_B = const_cast<float*>(p_gemm_descs[group_id].p_B);
+    auto p_C = p_gemm_descs[group_id].p_C;

     const auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N));
     BlockToCTileMap_LinearKSplit<MPerBlock, NPerBlock> b2c_tile_map(c_grid_desc_m_n, k_batch);
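Editor's note: the const_cast is there because the descriptor presumably stores A and B as const float*, while make_wave_buffer_resource_with_default_range<float> (used in the hunks below) takes a mutable float*. An illustration of the mismatch; the struct and the factory body here are stand-ins for the sketch, not CK's real definitions:

    // Stand-in for CK's buffer descriptor; the real resource packs base
    // address, range and stride words for the buffer_load/store instructions.
    struct BufferResource { float* base; };

    template <typename T>
    BufferResource make_wave_buffer_resource_with_default_range(T* p) { return {p}; }

    void example(const float* p_A)
    {
        // make_wave_buffer_resource_with_default_range<float>(p_A); // error: const float* -> float*
        auto res = make_wave_buffer_resource_with_default_range<float>(const_cast<float*>(p_A));
        (void)res;
    }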
@@ -124,11 +124,29 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
         const index_t B_k_tile_offset     = k_batch_id * KPerBlock;
         const index_t B_thread_tile_n_idx = get_thread_local_1d_id() % NPerBlock;

+        auto a_buffer_resource = make_wave_buffer_resource_with_default_range<float>(
+            p_A + A_m_tile_offset * stride_a + A_k_tile_offset);
+        auto b_buffer_resource = make_wave_buffer_resource_with_default_range<float>(
+            p_B + B_k_tile_offset * stride_b + B_n_tile_offset);
+
         for(index_t k = 0; k < KPerBlock; ++k)
         {
-            partial_result +=
-                p_A[(A_m_tile_offset + A_thread_tile_m_idx) * stride_a + A_k_tile_offset + k] *
-                p_B[(B_k_tile_offset + k) * stride_b + B_n_tile_offset + B_thread_tile_n_idx];
+            float a_val = llvm_amdgcn_raw_buffer_load_fp32(
+                a_buffer_resource,
+                (A_thread_tile_m_idx * stride_a + k) * sizeof(float),
+                0,
+                static_cast<index_t>(AmdBufferCoherenceEnum::DefaultCoherence));
+            float b_val = llvm_amdgcn_raw_buffer_load_fp32(
+                b_buffer_resource,
+                (k * stride_b + B_thread_tile_n_idx) * sizeof(float),
+                0,
+                static_cast<index_t>(AmdBufferCoherenceEnum::DefaultCoherence));
+            partial_result += a_val * b_val;
+            // partial_result +=
+            //     p_A[(A_m_tile_offset + A_thread_tile_m_idx) * stride_a + A_k_tile_offset + k]
+            //     * p_B[(B_k_tile_offset + k) * stride_b + B_n_tile_offset +
+            //     B_thread_tile_n_idx];
         }
     } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());
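Editor's note: the buffer intrinsics address memory as a byte offset from the resource's base, so the old flat index is split into a per-tile base (folded into the resource) plus a per-thread, per-iteration offset scaled by sizeof(float). Both decompositions hit the same element; a quick check with arbitrary example values (editor's sketch):

    #include <cassert>

    int main()
    {
        const int stride_a = 256, A_m_tile_offset = 128, A_k_tile_offset = 32;
        const int A_thread_tile_m_idx = 7, k = 3;

        // Old pointer arithmetic: one flat element index into p_A.
        const int old_idx =
            (A_m_tile_offset + A_thread_tile_m_idx) * stride_a + A_k_tile_offset + k;

        // New scheme: resource base index + (byte offset / sizeof(float)).
        const int base_idx   = A_m_tile_offset * stride_a + A_k_tile_offset; // in a_buffer_resource
        const int offset_idx = A_thread_tile_m_idx * stride_a + k;           // passed to the load

        assert(old_idx == base_idx + offset_idx);
        return 0;
    }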
@@ -136,10 +154,17 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
     if(!b2c_tile_map.IsFirstKSplitBlock())
     {
         // Assume we have MPerBlock x NPerBlock tile per each workgroup in contiguous memory.
-        p_workspace[get_block_1d_id() * MPerBlock * NPerBlock + get_thread_local_1d_id()] =
-            partial_result;
+        auto w_buffer_resource = make_wave_buffer_resource_with_default_range<float>(
+            p_workspace + get_block_1d_id() * MPerBlock * NPerBlock);
+        llvm_amdgcn_raw_buffer_store_fp32(partial_result,
+                                          w_buffer_resource,
+                                          get_thread_local_1d_id() * sizeof(float),
+                                          0,
+                                          static_cast<index_t>(AmdBufferCoherenceEnum::GLC));
+        // p_workspace[get_block_1d_id() * MPerBlock * NPerBlock + get_thread_local_1d_id()] =
+        //     partial_result;
     }
     __threadfence();

     const index_t output_tile_idx =
         __builtin_amdgcn_readfirstlane(b2c_tile_map.GetOutputTileIdx());
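Editor's note: the workspace store uses AmdBufferCoherenceEnum::GLC. On GCN/CDNA-class hardware the "globally coherent" bit keeps the access out of the non-coherent per-CU L1 so it is serviced by the L2 shared across compute units; that is what lets another workgroup observe the partial tile at all. Paired with __threadfence(), this forms a publish pattern, sketched portably below (illustration, not the kernel's code):

    #include <hip/hip_runtime.h>

    // Publish a partial tile for another workgroup: store the data, then fence
    // so the store is ordered before whatever flag/signal follows. In the
    // kernel the store itself is a GLC buffer store.
    __device__ void publish_partial(float* p_workspace, float partial_result,
                                    unsigned slot_base, unsigned tid)
    {
        p_workspace[slot_base + tid] = partial_result;
        __threadfence(); // make the store visible device-wide before signalling
    }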
@@ -158,10 +183,21 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
         // read actual flag value.
         for(index_t i = 1; i < neighbour_count; ++i)
         {
-            partial_result += p_workspace[(get_block_1d_id()) * MPerBlock * NPerBlock +
-                                          i * MPerBlock * NPerBlock +
-                                          get_thread_local_1d_id()];
+            // partial_result += p_workspace[(get_block_1d_id()) * MPerBlock * NPerBlock +
+            //                               i * MPerBlock * NPerBlock +
+            //                               get_thread_local_1d_id()];
+            auto w_buffer_resource = make_wave_buffer_resource_with_default_range<float>(
+                p_workspace + get_block_1d_id() * MPerBlock * NPerBlock +
+                i * MPerBlock * NPerBlock);
+            float value = llvm_amdgcn_raw_buffer_load_fp32(
+                w_buffer_resource,
+                get_thread_local_1d_id() * sizeof(float),
+                0,
+                static_cast<index_t>(AmdBufferCoherenceEnum::GLC));
+            partial_result += value;
         }
         __threadfence();
         // Signal waiting blocks that they can start use their workspace.
         work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
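Editor's note: the workspace is laid out as one contiguous MPerBlock * NPerBlock float slot per workgroup, indexed by block id; the first K-split block of an output tile sums its neighbours' slots on top of its own partial result, reading each with a GLC load so it sees the freshly published values. A host-side model of that reduction (editor's sketch; the naming is illustrative):

    #include <vector>

    // workspace[b] models the MPerBlock*NPerBlock slot owned by block b. Blocks
    // first+1 .. first+neighbour_count-1 hold partials for the same output tile.
    float reduce_partials(const std::vector<std::vector<float>>& workspace,
                          float own_partial, int first, int neighbour_count, int tid)
    {
        float acc = own_partial;
        for(int i = 1; i < neighbour_count; ++i)
            acc += workspace[first + i][tid]; // mirrors base + i*MPerBlock*NPerBlock
        return acc;
    }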
@@ -171,8 +207,17 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
         const index_t C_n_tile_offset     = block_n_id * NPerBlock;
         const index_t C_thread_tile_n_idx = get_thread_local_1d_id() % NPerBlock;

-        p_C[(C_m_tile_offset + C_thread_tile_m_idx) * stride_c + C_n_tile_offset +
-            C_thread_tile_n_idx] = partial_result;
+        auto c_buffer_resource = make_wave_buffer_resource_with_default_range<float>(
+            p_C + C_m_tile_offset * stride_c + C_n_tile_offset);
+        llvm_amdgcn_raw_buffer_store_fp32(
+            partial_result,
+            c_buffer_resource,
+            (C_thread_tile_m_idx * stride_c + C_thread_tile_n_idx) * sizeof(float),
+            0,
+            static_cast<index_t>(AmdBufferCoherenceEnum::DefaultCoherence));
+        // p_C[(C_m_tile_offset + C_thread_tile_m_idx) * stride_c + C_n_tile_offset +
+        //     C_thread_tile_n_idx] = partial_result;
     }
     else if(work_scheduler.HasTile())
     {
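Editor's summary: after this change, the A/B loads and the final C store use AmdBufferCoherenceEnum::DefaultCoherence (that data is private to the producing pass), and only the inter-workgroup workspace traffic pays for GLC. A hypothetical host-side launch for orientation; the grid/block geometry and the flag-buffer sizing are assumptions rather than part of the diff, and GemmArgDesc plus the kernel are taken from this test file:

    #include <hip/hip_runtime.h>
    #include <cstdint>

    constexpr int MPerBlock = 16, NPerBlock = 16, KPerBlock = 32; // illustrative only

    void launch(const GemmArgDesc* p_gemm_descs_dev, int grid_size,
                int tile_count, int k_batch)
    {
        float*    p_workspace = nullptr; // one MPerBlock*NPerBlock slot per workgroup (assumed)
        uint32_t* p_flags     = nullptr; // scheduler flags, one per output tile (assumed)
        hipMalloc(&p_workspace, size_t(grid_size) * MPerBlock * NPerBlock * sizeof(float));
        hipMalloc(&p_flags, size_t(tile_count) * sizeof(uint32_t));
        hipMemset(p_flags, 0, size_t(tile_count) * sizeof(uint32_t));

        grouped_gemm_naive_strided_tile_loop_reduce<MPerBlock, NPerBlock, KPerBlock>
            <<<grid_size, MPerBlock * NPerBlock>>>(
                p_gemm_descs_dev, p_workspace, p_flags, tile_count, k_batch);
    }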