Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d3341a67
"git@developer.sourcefind.cn:change/sglang.git" did not exist on "d737da5f17ebd179fa9d6a79fb28e6d09398848d"
Commit
d3341a67
authored
Aug 16, 2021
by
Jing Zhang
Browse files
xdlops refactor
parent
b62bf8c3
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
354 additions
and
617 deletions
+354
-617
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+186
-392
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
+84
-182
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+74
-35
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
+9
-7
host/driver_offline/src/conv_fwd_driver_offline.cpp
host/driver_offline/src/conv_fwd_driver_offline.cpp
+1
-1
No files found.
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
d3341a67
This diff is collapsed.
Click to expand it.
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
View file @
d3341a67
This diff is collapsed.
Click to expand it.
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @
d3341a67
...
@@ -709,19 +709,59 @@ struct XdlopsGemm
...
@@ -709,19 +709,59 @@ struct XdlopsGemm
static_assert
(
mfma_type
.
k
%
mfma_type
.
k_base
==
0
,
"k % kbase != 0!"
);
static_assert
(
mfma_type
.
k
%
mfma_type
.
k_base
==
0
,
"k % kbase != 0!"
);
}
}
template
<
typename
CM0N0M1N1M2N2GridDesc
>
__host__
__device__
static
constexpr
auto
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
const
CM0N0M1N1M2N2GridDesc
&
c_m0_n0_m1_n1_m2_n2_grid_desc
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
M0
=
c_m0_n0_m1_n1_m2_n2_grid_desc
.
GetLength
(
I0
);
constexpr
auto
N0
=
c_m0_n0_m1_n1_m2_n2_grid_desc
.
GetLength
(
I1
);
constexpr
auto
M1
=
c_m0_n0_m1_n1_m2_n2_grid_desc
.
GetLength
(
I2
);
constexpr
auto
N1
=
c_m0_n0_m1_n1_m2_n2_grid_desc
.
GetLength
(
I3
);
constexpr
auto
M2
=
c_m0_n0_m1_n1_m2_n2_grid_desc
.
GetLength
(
I4
);
constexpr
auto
N2
=
c_m0_n0_m1_n1_m2_n2_grid_desc
.
GetLength
(
I5
);
static_assert
(
N2
==
mfma_type
.
num_threads_blk
,
""
);
static_assert
(
M2
==
(
mfma_type
.
num_groups_blk
*
mfma_type
.
num_output_blks
*
mfma_type
.
group_size
),
""
);
return
transform_dynamic_tensor_descriptor
(
c_m0_n0_m1_n1_m2_n2_grid_desc
,
make_tuple
(
make_pass_through_transform
(
M0
),
make_pass_through_transform
(
N0
),
make_pass_through_transform
(
M1
),
make_pass_through_transform
(
N1
),
make_unmerge_transform
(
make_tuple
(
mfma_type
.
num_groups_blk
,
mfma_type
.
num_input_blks
,
mfma_type
.
group_size
)),
make_pass_through_transform
(
mfma_type
.
num_threads_blk
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
,
5
,
6
>
{},
Sequence
<
7
>
{}));
}
__device__
static
constexpr
index_t
GetRegSizePerXdlops
()
__device__
static
constexpr
index_t
GetRegSizePerXdlops
()
{
{
return
MPerXdlops
*
NPerXdlops
/
mfma_type
.
wave_size
;
return
MPerXdlops
*
NPerXdlops
/
mfma_type
.
wave_size
;
}
}
template
<
class
ADesc
,
template
<
index_t
c_offset
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
class
BDesc
,
class
CDesc
,
index_t
m0
,
index_t
n0
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
Run
(
const
FloatA
&
p_a_wave
,
const
FloatB
&
p_b_wave
,
FloatC
&
p_c_thread
)
const
__device__
void
Run
(
const
FloatA
&
p_a_wave
,
const
FloatB
&
p_b_wave
,
FloatC
&
p_c_thread
)
const
{
{
static_assert
(
is_same
<
base_type
,
float
>::
value
||
is_same
<
base_type
,
half_t
>::
value
||
static_assert
(
is_same
<
base_type
,
float
>::
value
||
is_same
<
base_type
,
half_t
>::
value
||
...
@@ -730,24 +770,35 @@ struct XdlopsGemm
...
@@ -730,24 +770,35 @@ struct XdlopsGemm
static_assert
(
KPack
%
mfma_type
.
k_base
==
0
,
"KPack cannot be divided by k_base"
);
static_assert
(
KPack
%
mfma_type
.
k_base
==
0
,
"KPack cannot be divided by k_base"
);
constexpr
index_t
c_offset
=
CDesc
{}.
CalculateOffset
(
make_tuple
(
m0
,
n0
))
*
GetNumXdlops
();
static_for
<
0
,
KPack
/
mfma_type
.
k_base
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
KPack
,
mfma_type
.
k_base
>
{}([
&
](
auto
k
)
{
constexpr
index_t
a_offset
=
ADesc
{}.
CalculateOffset
(
make_tuple
(
0
,
m0
,
0
,
k
));
constexpr
index_t
b_offset
=
BDesc
{}.
CalculateOffset
(
make_tuple
(
0
,
n0
,
0
,
k
));
mfma_type
.
template
run
<
MPerXdlops
,
NPerXdlops
,
c_offset
>(
mfma_type
.
template
run
<
MPerXdlops
,
NPerXdlops
,
c_offset
>(
p_a_wave
[
Number
<
a_offset
/
mfma_type
.
k_base
>
{}],
p_a_wave
[
k
],
p_b_wave
[
k
],
p_c_thread
);
p_b_wave
[
Number
<
b_offset
/
mfma_type
.
k_base
>
{}],
p_c_thread
);
});
});
}
}
static
constexpr
auto
GetBlkIdx
()
{
const
auto
threadidx_to_blk_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
1
,
mfma_type
.
num_input_blks
,
mfma_type
.
num_threads_blk
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
blk_idx
=
threadidx_to_blk_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
const
auto
blk_id
=
blk_idx
[
Number
<
1
>
{}];
const
auto
blk_td
=
blk_idx
[
Number
<
2
>
{}];
return
make_tuple
(
blk_id
,
blk_td
);
}
__device__
static
CIndex
GetBeginOfThreadBlk
(
index_t
xdlops_i
,
index_t
blk_i
)
__device__
static
CIndex
GetBeginOfThreadBlk
(
index_t
xdlops_i
,
index_t
blk_i
)
{
{
const
index_t
laneId
=
get_thread_local_1d_id
()
%
mfma_type
.
wave_size
;
const
auto
blk_idx
=
GetBlkIdx
();
const
index_t
blk_id
=
laneId
/
mfma_type
.
num_threads_blk
;
const
index_t
blk_td
=
laneId
%
mfma_type
.
num_threads_blk
;
const
auto
blk_id
=
blk_idx
[
Number
<
0
>
{}];
const
auto
blk_td
=
blk_idx
[
Number
<
1
>
{}];
index_t
n_offset
=
blk_i
*
mfma_type
.
n
+
blk_td
;
index_t
n_offset
=
blk_i
*
mfma_type
.
n
+
blk_td
;
index_t
m_offset
=
xdlops_i
*
mfma_type
.
m
+
blk_id
*
mfma_type
.
group_size
;
index_t
m_offset
=
xdlops_i
*
mfma_type
.
m
+
blk_id
*
mfma_type
.
group_size
;
...
@@ -755,24 +806,12 @@ struct XdlopsGemm
...
@@ -755,24 +806,12 @@ struct XdlopsGemm
return
CIndex
{
m_offset
,
n_offset
};
return
CIndex
{
m_offset
,
n_offset
};
}
}
static
constexpr
index_t
MRepeats
=
GetXdlopsInfo
().
MRepeats
;
static
constexpr
index_t
NRepeats
=
GetXdlopsInfo
().
NRepeats
;
static
constexpr
index_t
MPerXdlops
=
GetXdlopsInfo
().
MPerXdlops
;
static
constexpr
index_t
MPerXdlops
=
GetXdlopsInfo
().
MPerXdlops
;
static
constexpr
index_t
NPerXdlops
=
GetXdlopsInfo
().
NPerXdlops
;
static
constexpr
index_t
NPerXdlops
=
GetXdlopsInfo
().
NPerXdlops
;
static
constexpr
bool
IsKReduction
=
GetXdlopsInfo
().
IsKReduction
();
static
constexpr
bool
IsABroadcast
=
GetXdlopsInfo
().
IsABroadcast
();
static
constexpr
index_t
KPerXdlops
=
GetXdlopsInfo
().
GetKPerXdlops
();
static
constexpr
index_t
KPerXdlops
=
GetXdlopsInfo
().
GetKPerXdlops
();
static
constexpr
auto
GetBlkId
(
const
index_t
lane_id
)
static
constexpr
bool
IsKReduction
=
GetXdlopsInfo
().
IsKReduction
();
{
static
constexpr
bool
IsABroadcast
=
GetXdlopsInfo
().
IsABroadcast
();
return
lane_id
/
mfma_type
.
num_threads_blk
;
}
static
constexpr
auto
GetBlkTd
(
const
index_t
lane_id
)
{
return
lane_id
%
mfma_type
.
num_threads_blk
;
}
static
constexpr
auto
mfma_type
=
GetXdlopsInfo
().
mfma_type
;
static
constexpr
auto
mfma_type
=
GetXdlopsInfo
().
mfma_type
;
...
@@ -794,7 +833,7 @@ struct XdlopsGemm
...
@@ -794,7 +833,7 @@ struct XdlopsGemm
}
}
};
};
__host__
__device__
static
constexpr
auto
GetCLayout
()
{
return
CLayout
{};
}
__host__
__device__
static
constexpr
auto
GetC
Xdlops
Layout
()
{
return
CLayout
{};
}
};
};
}
// namespace ck
}
// namespace ck
...
...
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
View file @
d3341a67
...
@@ -129,9 +129,10 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
...
@@ -129,9 +129,10 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"
);
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"
);
}
}
const
auto
c_m0_m1_m2_n_grid_desc
=
GridwiseGemm
::
MakeCM0M1M2NGridDescriptor
(
c_m_n_grid_desc
);
const
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
=
GridwiseGemm
::
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
c_m_n_grid_desc
);
using
CM0
M1M2N
GridDesc
=
decltype
(
c_m0_
m1_m2_n
_grid_desc
);
using
CM0
N0M1N1M2M3M4N2
GridDesc
=
decltype
(
c_m0_
n0_m1_n1_m2_m3_m4_n2
_grid_desc
);
const
auto
c_block_cluster_adaptor
=
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_m_n_grid_desc
);
const
auto
c_block_cluster_adaptor
=
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_m_n_grid_desc
);
...
@@ -144,7 +145,7 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
...
@@ -144,7 +145,7 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
FloatC
,
FloatC
,
remove_reference_t
<
AK0MK1GridDesc
>
,
remove_reference_t
<
AK0MK1GridDesc
>
,
remove_reference_t
<
BK0NK1GridDesc
>
,
remove_reference_t
<
BK0NK1GridDesc
>
,
remove_reference_t
<
CM0
M1M2N
GridDesc
>
,
remove_reference_t
<
CM0
N0M1N1M2M3M4N2
GridDesc
>
,
remove_reference_t
<
CBlockClusterAdaptor
>>
;
remove_reference_t
<
CBlockClusterAdaptor
>>
;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
...
@@ -158,18 +159,18 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
...
@@ -158,18 +159,18 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
p_c_grid
,
p_c_grid
,
a_k0_m_k1_grid_desc
,
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m0_
m1_m2_n
_grid_desc
,
c_m0_
n0_m1_n1_m2_m3_m4_n2
_grid_desc
,
c_block_cluster_adaptor
);
c_block_cluster_adaptor
);
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem
a_k0_m_k1_grid_desc_dev_buf
(
sizeof
(
AK0MK1GridDesc
));
DeviceMem
a_k0_m_k1_grid_desc_dev_buf
(
sizeof
(
AK0MK1GridDesc
));
DeviceMem
b_k0_n_k1_grid_desc_dev_buf
(
sizeof
(
BK0NK1GridDesc
));
DeviceMem
b_k0_n_k1_grid_desc_dev_buf
(
sizeof
(
BK0NK1GridDesc
));
DeviceMem
c_m0_
m1_m2_n
_grid_desc_dev_buf
(
sizeof
(
CM0
M1M2N
GridDesc
));
DeviceMem
c_m0_
n0_m1_n1_m2_m3_m4_n2
_grid_desc_dev_buf
(
sizeof
(
CM0
N0M1N1M2M3M4N2
GridDesc
));
DeviceMem
c_block_cluster_adaptor_dev_buf
(
sizeof
(
CBlockClusterAdaptor
));
DeviceMem
c_block_cluster_adaptor_dev_buf
(
sizeof
(
CBlockClusterAdaptor
));
a_k0_m_k1_grid_desc_dev_buf
.
ToDevice
(
&
a_k0_m_k1_grid_desc
);
a_k0_m_k1_grid_desc_dev_buf
.
ToDevice
(
&
a_k0_m_k1_grid_desc
);
b_k0_n_k1_grid_desc_dev_buf
.
ToDevice
(
&
b_k0_n_k1_grid_desc
);
b_k0_n_k1_grid_desc_dev_buf
.
ToDevice
(
&
b_k0_n_k1_grid_desc
);
c_m0_
m1_m2_n
_grid_desc_dev_buf
.
ToDevice
(
&
c_m0_
m1_m2_n
_grid_desc
);
c_m0_
n0_m1_n1_m2_m3_m4_n2
_grid_desc_dev_buf
.
ToDevice
(
&
c_m0_
n0_m1_n1_m2_m3_m4_n2
_grid_desc
);
c_block_cluster_adaptor_dev_buf
.
ToDevice
(
&
c_block_cluster_adaptor
);
c_block_cluster_adaptor_dev_buf
.
ToDevice
(
&
c_block_cluster_adaptor
);
float
ave_time
=
launch_and_time_kernel
(
float
ave_time
=
launch_and_time_kernel
(
...
@@ -183,7 +184,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
...
@@ -183,7 +184,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
p_c_grid
,
p_c_grid
,
cast_pointer_to_constant_address_space
(
a_k0_m_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
a_k0_m_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_k0_n_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_k0_n_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_m0_m1_m2_n_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
cast_pointer_to_constant_address_space
(
c_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
#endif
#endif
return
ave_time
;
return
ave_time
;
...
...
host/driver_offline/src/conv_fwd_driver_offline.cpp
View file @
d3341a67
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#define USE_CONV_FWD_V4R4R2_NHWC 1
#define USE_CONV_FWD_V4R4R2_NHWC 1
#define USE_CONV_FWD_V6R1_NCHW 0
#define USE_CONV_FWD_V6R1_NCHW 0
#define USE_CONV_FWD_V5R1_NCHW 0
#define USE_CONV_FWD_V5R1_NCHW 0
#define USE_CONV_FWD_V4R4R2_XDL_NCHW
0
#define USE_CONV_FWD_V4R4R2_XDL_NCHW
1
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0
enum
ConvForwardAlgo
enum
ConvForwardAlgo
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment