Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
81ea7c0a
"docs/source/en/quicktour.mdx" did not exist on "86ac3ea1d7fbc2f40bb0ae3dc9f045ed78c6fe2e"
Commit
81ea7c0a
authored
May 09, 2023
by
Po-Yen, Chen
Browse files
Report timestamp of gridwise gemm
parent
4a26559e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
128 additions
and
6 deletions
+128
-6
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
+63
-3
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
+21
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
+22
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+22
-1
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
View file @
81ea7c0a
...
@@ -54,8 +54,14 @@ struct GridwiseGemmPipeline_v1<1>
...
@@ -54,8 +54,14 @@ struct GridwiseGemmPipeline_v1<1>
const
BBlockTransferStep
&
b_block_copy_step
,
const
BBlockTransferStep
&
b_block_copy_step
,
const
BlockwiseGemm
&
blockwise_gemm
,
const
BlockwiseGemm
&
blockwise_gemm
,
CThreadBuffer
&
c_thread_buf
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
index_t
num_loop
,
long
&
loop_start
,
long
&
loop_end
)
{
{
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"; [POYENC] pipeline start"
::
);
__builtin_amdgcn_sched_barrier
(
0
);
// preload data into LDS
// preload data into LDS
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
...
@@ -69,6 +75,11 @@ struct GridwiseGemmPipeline_v1<1>
...
@@ -69,6 +75,11 @@ struct GridwiseGemmPipeline_v1<1>
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
__builtin_amdgcn_sched_barrier
(
0
);
loop_start
=
__builtin_readcyclecounter
();
asm
volatile
(
"; [POYENC] hot-loop start"
::
);
__builtin_amdgcn_sched_barrier
(
0
);
// main body
// main body
if
constexpr
(
HasMainLoop
)
if
constexpr
(
HasMainLoop
)
{
{
...
@@ -102,6 +113,15 @@ struct GridwiseGemmPipeline_v1<1>
...
@@ -102,6 +113,15 @@ struct GridwiseGemmPipeline_v1<1>
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
}
}
__builtin_amdgcn_sched_barrier
(
0
);
loop_end
=
__builtin_readcyclecounter
();
asm
volatile
(
"; [POYENC] hot-loop end"
::
);
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"; [POYENC] pipeline end"
::
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
};
};
...
@@ -152,8 +172,14 @@ struct GridwiseGemmPipeline_v1<2>
...
@@ -152,8 +172,14 @@ struct GridwiseGemmPipeline_v1<2>
const
BBlockTransferStep
&
b_block_copy_step
,
const
BBlockTransferStep
&
b_block_copy_step
,
const
BlockwiseGemm
&
blockwise_gemm
,
const
BlockwiseGemm
&
blockwise_gemm
,
CThreadBuffer
&
c_thread_buf
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
index_t
num_loop
,
long
&
loop_start
,
long
&
loop_end
)
{
{
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"; [POYENC] pipeline start"
::
);
__builtin_amdgcn_sched_barrier
(
0
);
// preload data into LDS
// preload data into LDS
{
{
// Read 0
// Read 0
...
@@ -172,6 +198,11 @@ struct GridwiseGemmPipeline_v1<2>
...
@@ -172,6 +198,11 @@ struct GridwiseGemmPipeline_v1<2>
// Initialize C
// Initialize C
c_thread_buf
.
Clear
();
c_thread_buf
.
Clear
();
__builtin_amdgcn_sched_barrier
(
0
);
loop_start
=
__builtin_readcyclecounter
();
asm
volatile
(
"; [POYENC] hot-loop start"
::
);
__builtin_amdgcn_sched_barrier
(
0
);
// main body
// main body
if
constexpr
(
HasMainLoop
)
if
constexpr
(
HasMainLoop
)
{
{
...
@@ -250,6 +281,15 @@ struct GridwiseGemmPipeline_v1<2>
...
@@ -250,6 +281,15 @@ struct GridwiseGemmPipeline_v1<2>
// Gemm num_loop - 1
// Gemm num_loop - 1
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
}
}
__builtin_amdgcn_sched_barrier
(
0
);
loop_end
=
__builtin_readcyclecounter
();
asm
volatile
(
"; [POYENC] hot-loop end"
::
);
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"; [POYENC] pipeline end"
::
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
};
};
...
@@ -295,8 +335,14 @@ struct GridwiseGemmPipelineInterwave_v1<1>
...
@@ -295,8 +335,14 @@ struct GridwiseGemmPipelineInterwave_v1<1>
const
BBlockTransferStep
&
b_block_copy_step
,
const
BBlockTransferStep
&
b_block_copy_step
,
const
BlockwiseGemm
&
blockwise_gemm
,
const
BlockwiseGemm
&
blockwise_gemm
,
CThreadBuffer
&
c_thread_buf
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
index_t
num_loop
,
long
&
loop_start
,
long
&
loop_end
)
{
{
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"; [POYENC] pipeline start"
::
);
__builtin_amdgcn_sched_barrier
(
0
);
// preload data into LDS
// preload data into LDS
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
...
@@ -310,6 +356,11 @@ struct GridwiseGemmPipelineInterwave_v1<1>
...
@@ -310,6 +356,11 @@ struct GridwiseGemmPipelineInterwave_v1<1>
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
__builtin_amdgcn_sched_barrier
(
0
);
loop_start
=
__builtin_readcyclecounter
();
asm
volatile
(
"; [POYENC] hot-loop start"
::
);
__builtin_amdgcn_sched_barrier
(
0
);
// main body
// main body
if
constexpr
(
HasMainLoop
)
if
constexpr
(
HasMainLoop
)
{
{
...
@@ -343,6 +394,15 @@ struct GridwiseGemmPipelineInterwave_v1<1>
...
@@ -343,6 +394,15 @@ struct GridwiseGemmPipelineInterwave_v1<1>
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
}
}
__builtin_amdgcn_sched_barrier
(
0
);
loop_end
=
__builtin_readcyclecounter
();
asm
volatile
(
"; [POYENC] hot-loop end"
::
);
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"; [POYENC] pipeline end"
::
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
};
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
View file @
81ea7c0a
...
@@ -49,8 +49,14 @@ struct GridwiseGemmPipeline_v2
...
@@ -49,8 +49,14 @@ struct GridwiseGemmPipeline_v2
const
BBlockTransferStep
&
b_block_copy_step
,
const
BBlockTransferStep
&
b_block_copy_step
,
const
BlockwiseGemm
&
blockwise_gemm
,
const
BlockwiseGemm
&
blockwise_gemm
,
CThreadBuffer
&
c_thread_buf
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
index_t
num_loop
,
long
&
loop_start
,
long
&
loop_end
)
{
{
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"; [POYENC] pipeline start"
::
);
__builtin_amdgcn_sched_barrier
(
0
);
// global read 0
// global read 0
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
...
@@ -72,6 +78,11 @@ struct GridwiseGemmPipeline_v2
...
@@ -72,6 +78,11 @@ struct GridwiseGemmPipeline_v2
// global Read 1
// global Read 1
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
__builtin_amdgcn_sched_barrier
(
0
);
loop_start
=
__builtin_readcyclecounter
();
asm
volatile
(
"; [POYENC] hot-loop start"
::
);
__builtin_amdgcn_sched_barrier
(
0
);
// main body
// main body
if
constexpr
(
HasMainLoop
)
if
constexpr
(
HasMainLoop
)
{
{
...
@@ -122,6 +133,15 @@ struct GridwiseGemmPipeline_v2
...
@@ -122,6 +133,15 @@ struct GridwiseGemmPipeline_v2
// GEMM num_loop - 1
// GEMM num_loop - 1
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
}
}
__builtin_amdgcn_sched_barrier
(
0
);
loop_end
=
__builtin_readcyclecounter
();
asm
volatile
(
"; [POYENC] hot-loop end"
::
);
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"; [POYENC] pipeline end"
::
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
};
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
View file @
81ea7c0a
...
@@ -292,6 +292,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -292,6 +292,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2CTileMap
&
block_2_ctile_map
)
const
Block2CTileMap
&
block_2_ctile_map
)
{
{
__builtin_amdgcn_sched_barrier
(
0
);
const
long
kernel_start
=
__builtin_readcyclecounter
();
asm
volatile
(
"; [POYENC] kernel start"
::
);
__builtin_amdgcn_sched_barrier
(
0
);
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
...
@@ -436,6 +441,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -436,6 +441,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
))
/
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
))
/
KPerBlock
);
KPerBlock
);
long
loop_start
=
0
,
loop_end
=
0
;
gridwise_gemm_pipeline
.
template
Run
<
HasMainKBlockLoop
>(
a_grid_desc_ak0_m_ak1
,
gridwise_gemm_pipeline
.
template
Run
<
HasMainKBlockLoop
>(
a_grid_desc_ak0_m_ak1
,
a_block_desc_ak0_m_ak1
,
a_block_desc_ak0_m_ak1
,
a_blockwise_copy
,
a_blockwise_copy
,
...
@@ -450,7 +456,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -450,7 +456,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
b_block_slice_copy_step
,
b_block_slice_copy_step
,
blockwise_gemm
,
blockwise_gemm
,
c_thread_buf
,
c_thread_buf
,
num_k_block_main_loop
);
num_k_block_main_loop
,
loop_start
,
loop_end
);
// shuffle C and write out
// shuffle C and write out
{
{
...
@@ -647,6 +655,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -647,6 +655,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_global_step
);
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_global_step
);
}
}
});
});
__builtin_amdgcn_sched_barrier
(
0
);
const
long
kernel_end
=
__builtin_readcyclecounter
();
asm
volatile
(
"; [POYENC] kernel end"
::
);
__builtin_amdgcn_sched_barrier
(
0
);
if
(
blockIdx
.
x
==
0
&&
threadIdx
.
x
==
0
)
{
printf
(
"[POYENC] prolog: %ld, hot-loop: %ld, epilog: %ld
\n
"
,
loop_start
-
kernel_start
,
loop_end
-
loop_start
,
kernel_end
-
loop_end
);
}
}
}
}
}
};
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
81ea7c0a
...
@@ -331,6 +331,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -331,6 +331,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const
CElementwiseOperation
&
c_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
Block2CTileMap
&
block_2_ctile_map
)
const
Block2CTileMap
&
block_2_ctile_map
)
{
{
__builtin_amdgcn_sched_barrier
(
0
);
const
long
kernel_start
=
__builtin_readcyclecounter
();
asm
volatile
(
"; [POYENC] kernel start"
::
);
__builtin_amdgcn_sched_barrier
(
0
);
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_k0_m_k1
.
GetElementSpaceSize
());
p_a_grid
,
a_grid_desc_k0_m_k1
.
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
...
@@ -469,6 +474,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -469,6 +474,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// gridwise GEMM pipeline
// gridwise GEMM pipeline
const
index_t
num_k_block_main_loop
=
__builtin_amdgcn_readfirstlane
(
K0
/
K0PerBlock
);
const
index_t
num_k_block_main_loop
=
__builtin_amdgcn_readfirstlane
(
K0
/
K0PerBlock
);
long
loop_start
=
0
,
loop_end
=
0
;
GridwiseGemmPipe
::
template
Run
<
HasMainKBlockLoop
>(
a_grid_desc_k0_m_k1
,
GridwiseGemmPipe
::
template
Run
<
HasMainKBlockLoop
>(
a_grid_desc_k0_m_k1
,
a_block_desc_k0_m_k1
,
a_block_desc_k0_m_k1
,
a_blockwise_copy
,
a_blockwise_copy
,
...
@@ -483,7 +489,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -483,7 +489,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
b_block_slice_copy_step
,
b_block_slice_copy_step
,
blockwise_gemm
,
blockwise_gemm
,
c_thread_buf
,
c_thread_buf
,
num_k_block_main_loop
);
num_k_block_main_loop
,
loop_start
,
loop_end
);
// output: register to global memory
// output: register to global memory
{
{
...
@@ -561,6 +569,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -561,6 +569,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf
,
c_thread_buf
,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_grid_buf
);
c_grid_buf
);
__builtin_amdgcn_sched_barrier
(
0
);
const
long
kernel_end
=
__builtin_readcyclecounter
();
asm
volatile
(
"; [POYENC] kernel end"
::
);
__builtin_amdgcn_sched_barrier
(
0
);
if
(
blockIdx
.
x
==
0
&&
threadIdx
.
x
==
0
)
{
printf
(
"[POYENC] prolog: %ld, hot-loop: %ld, epilog: %ld
\n
"
,
loop_start
-
kernel_start
,
loop_end
-
loop_start
,
kernel_end
-
loop_end
);
}
}
}
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment