Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
26f221eb
"test/contraction/test_contraction_interface.cpp" did not exist on "3eee1b9b8fa13d044509089c7fc8186f4439d412"
Commit
26f221eb
authored
Nov 29, 2024
by
rocking
Browse files
Support Pure quant kernel
parent
bb652696
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
97 additions
and
47 deletions
+97
-47
example/ck_tile/12_smoothquant/example_smoothquant.cpp
example/ck_tile/12_smoothquant/example_smoothquant.cpp
+8
-3
example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp
.../12_smoothquant/instances/smoothquant_instance_common.hpp
+2
-1
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_instance_common.hpp
...smoothquant/instances/moe_smoothquant_instance_common.hpp
+2
-1
include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp
...ude/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp
+18
-11
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp
...ps/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp
+19
-10
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp
...ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp
+3
-1
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp
...ps/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp
+45
-20
No files found.
example/ck_tile/12_smoothquant/example_smoothquant.cpp
View file @
26f221eb
...
@@ -89,6 +89,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -89,6 +89,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
xscale_buf
.
ToDevice
(
xscale_host
.
data
());
xscale_buf
.
ToDevice
(
xscale_host
.
data
());
constexpr
bool
kTwoPass
=
true
;
constexpr
bool
kTwoPass
=
true
;
constexpr
bool
kSmoothX
=
true
;
using
BlockWarps
=
ck_tile
::
sequence
<
2
,
2
>
;
using
BlockWarps
=
ck_tile
::
sequence
<
2
,
2
>
;
using
BlockTile
=
ck_tile
::
sequence
<
2
,
128
>
;
using
BlockTile
=
ck_tile
::
sequence
<
2
,
128
>
;
...
@@ -103,7 +104,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -103,7 +104,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
QYDataType
,
QYDataType
,
Shape
,
Shape
,
true
,
true
,
kTwoPass
>
;
kTwoPass
,
kSmoothX
>
;
using
OnePassPipeline
=
ck_tile
::
SmoothquantPipelineOnePass
<
Problem
>
;
using
OnePassPipeline
=
ck_tile
::
SmoothquantPipelineOnePass
<
Problem
>
;
using
TwoPassPipeline
=
ck_tile
::
SmoothquantPipelineTwoPass
<
Problem
>
;
using
TwoPassPipeline
=
ck_tile
::
SmoothquantPipelineTwoPass
<
Problem
>
;
...
@@ -141,8 +143,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -141,8 +143,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
for
(
int
m_
=
0
;
m_
<
m
;
++
m_
)
for
(
int
m_
=
0
;
m_
<
m
;
++
m_
)
{
{
auto
v_x
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_host
(
m_
,
n_
));
auto
v_x
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_host
(
m_
,
n_
));
y_host
(
m_
,
n_
)
=
v_x
*
v_xscale
;
if
constexpr
(
kSmoothX
)
y_host
(
m_
,
n_
)
=
v_x
*
v_xscale
;
else
y_host
(
m_
,
n_
)
=
v_x
;
}
}
};
};
...
...
example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp
View file @
26f221eb
...
@@ -41,7 +41,8 @@ float smoothquant_(const S& s, A a)
...
@@ -41,7 +41,8 @@ float smoothquant_(const S& s, A a)
typename
SmoothquantTypeConfig
<
DataType
>::
QYDataType
,
typename
SmoothquantTypeConfig
<
DataType
>::
QYDataType
,
typename
Traits_
::
Shape
,
typename
Traits_
::
Shape
,
Traits_
::
kPadN
,
Traits_
::
kPadN
,
Traits_
::
kTwoPass
>
;
Traits_
::
kTwoPass
,
true
>
;
using
OnePassPipeline
=
ck_tile
::
SmoothquantPipelineOnePass
<
PipelineProblem
>
;
using
OnePassPipeline
=
ck_tile
::
SmoothquantPipelineOnePass
<
PipelineProblem
>
;
using
TwoPassPipeline
=
ck_tile
::
SmoothquantPipelineTwoPass
<
PipelineProblem
>
;
using
TwoPassPipeline
=
ck_tile
::
SmoothquantPipelineTwoPass
<
PipelineProblem
>
;
...
...
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_instance_common.hpp
View file @
26f221eb
...
@@ -41,7 +41,8 @@ float moe_smoothquant_(const S& s, A a)
...
@@ -41,7 +41,8 @@ float moe_smoothquant_(const S& s, A a)
typename
MoeSmoothquantTypeConfig
<
DataType
>::
QYDataType
,
typename
MoeSmoothquantTypeConfig
<
DataType
>::
QYDataType
,
typename
Traits_
::
Shape
,
typename
Traits_
::
Shape
,
Traits_
::
kPadN
,
Traits_
::
kPadN
,
Traits_
::
kTwoPass
>
;
Traits_
::
kTwoPass
,
true
>
;
using
OnePassPipeline
=
ck_tile
::
SmoothquantPipelineOnePass
<
PipelineProblem
>
;
using
OnePassPipeline
=
ck_tile
::
SmoothquantPipelineOnePass
<
PipelineProblem
>
;
using
TwoPassPipeline
=
ck_tile
::
SmoothquantPipelineTwoPass
<
PipelineProblem
>
;
using
TwoPassPipeline
=
ck_tile
::
SmoothquantPipelineTwoPass
<
PipelineProblem
>
;
...
...
include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp
View file @
26f221eb
...
@@ -40,6 +40,7 @@ struct Smoothquant
...
@@ -40,6 +40,7 @@ struct Smoothquant
static
constexpr
bool
kPadM
=
false
;
// always no need to pad along M
static
constexpr
bool
kPadM
=
false
;
// always no need to pad along M
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kTwoPass
=
Problem
::
kTwoPass
;
static
constexpr
bool
kTwoPass
=
Problem
::
kTwoPass
;
static
constexpr
bool
kSmoothX
=
Problem
::
kSmoothX
;
static
constexpr
index_t
ThreadPerWarp_N
=
Problem
::
BlockShape
::
ThreadPerWarp_N
;
static
constexpr
index_t
ThreadPerWarp_N
=
Problem
::
BlockShape
::
ThreadPerWarp_N
;
static
constexpr
index_t
Vector_N
=
Problem
::
BlockShape
::
Vector_N
;
static
constexpr
index_t
Vector_N
=
Problem
::
BlockShape
::
Vector_N
;
...
@@ -95,6 +96,7 @@ struct Smoothquant
...
@@ -95,6 +96,7 @@ struct Smoothquant
std
::
string
n
;
std
::
string
n
;
if
(
kPadN
)
n
+=
"_pn"
;
if
(
kPadN
)
n
+=
"_pn"
;
if
(
kTwoPass
)
n
+=
"_2p"
;
if
(
kTwoPass
)
n
+=
"_2p"
;
if
(
kSmoothX
)
n
+=
"_sx"
;
return
n
;
}();
return
n
;
}();
#define _SS_ std::string
#define _SS_ std::string
...
@@ -127,17 +129,22 @@ struct Smoothquant
...
@@ -127,17 +129,22 @@ struct Smoothquant
}();
}();
const
auto
xscale_window
=
[
&
]()
{
const
auto
xscale_window
=
[
&
]()
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
if
constexpr
(
kSmoothX
)
static_cast
<
const
XScaleDataType
*>
(
kargs
.
p_xscale
),
{
make_tuple
(
kargs
.
n
),
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
make_tuple
(
1
),
static_cast
<
const
XScaleDataType
*>
(
kargs
.
p_xscale
),
number
<
Vector_N
>
{},
make_tuple
(
kargs
.
n
),
number
<
1
>
{});
make_tuple
(
1
),
number
<
Vector_N
>
{},
const
auto
tmp2_
=
number
<
1
>
{});
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_N
>
{}),
sequence
<
kPadN
>
{});
const
auto
tmp2_
=
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_N
>
{}),
{
0
});
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_N
>
{}),
sequence
<
kPadN
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_N
>
{}),
{
0
});
}
else
return
make_null_tile_window
(
make_tuple
(
number
<
Block_N
>
{}));
}();
}();
auto
yscale_window
=
[
&
]()
{
auto
yscale_window
=
[
&
]()
{
...
...
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp
View file @
26f221eb
...
@@ -23,9 +23,10 @@ struct SmoothquantPipelineOnePass
...
@@ -23,9 +23,10 @@ struct SmoothquantPipelineOnePass
using
YScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
using
YScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
//
TODO - BlockSmoothquantProblem::kP
adM
static
constexpr
bool
kPadM
=
false
;
//
No need to p
ad
M
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
UseMax3
=
true
;
// TODO - Move to trait
static
constexpr
bool
kSmoothX
=
Problem
::
kSmoothX
;
static
constexpr
bool
UseMax3
=
true
;
// TODO - Move to Problem
static
constexpr
const
char
*
name
=
[]()
{
static
constexpr
const
char
*
name
=
[]()
{
if
constexpr
(
kNeedCrossWarpSync
)
if
constexpr
(
kNeedCrossWarpSync
)
...
@@ -67,14 +68,22 @@ struct SmoothquantPipelineOnePass
...
@@ -67,14 +68,22 @@ struct SmoothquantPipelineOnePass
auto
block_reduce2d_cross_warp_sync
=
auto
block_reduce2d_cross_warp_sync
=
Policy
::
template
GetBlockReduce2dCrossWarpSync
<
Problem
>();
Policy
::
template
GetBlockReduce2dCrossWarpSync
<
Problem
>();
const
auto
x
=
load_tile
(
x_window
);
const
auto
x
=
load_tile
(
x_window
);
const
auto
xscale
=
load_tile
(
xscale_window
);
auto
y
=
tile_elementwise_in
(
auto
y
=
[
&
]()
{
[
&
](
const
auto
&
a
,
const
auto
&
b
)
{
if
constexpr
(
kSmoothX
)
return
type_convert
<
ComputeDataType
>
(
a
)
*
type_convert
<
ComputeDataType
>
(
b
);
{
},
const
auto
xscale
=
load_tile
(
xscale_window
);
x
,
return
tile_elementwise_in
(
xscale
);
[
&
](
const
auto
&
a
,
const
auto
&
b
)
{
return
type_convert
<
ComputeDataType
>
(
a
)
*
type_convert
<
ComputeDataType
>
(
b
);
},
x
,
xscale
);
}
else
return
cast_tile
<
ComputeDataType
>
(
x
);
}();
// compute absmax, cross-lane->cross-warp
// compute absmax, cross-lane->cross-warp
auto
absmax
=
[
&
]()
{
auto
absmax
=
[
&
]()
{
...
...
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp
View file @
26f221eb
...
@@ -15,7 +15,8 @@ template <typename XDataType_,
...
@@ -15,7 +15,8 @@ template <typename XDataType_,
typename
QYDataType_
,
typename
QYDataType_
,
typename
BlockShape_
,
typename
BlockShape_
,
bool
kPadN_
,
bool
kPadN_
,
bool
kTwoPass_
>
bool
kTwoPass_
,
bool
kSmoothX_
>
struct
SmoothquantPipelineProblem
struct
SmoothquantPipelineProblem
{
{
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
...
@@ -30,6 +31,7 @@ struct SmoothquantPipelineProblem
...
@@ -30,6 +31,7 @@ struct SmoothquantPipelineProblem
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
static
constexpr
bool
kSmoothX
=
kSmoothX_
;
};
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp
View file @
26f221eb
...
@@ -23,8 +23,9 @@ struct SmoothquantPipelineTwoPass
...
@@ -23,8 +23,9 @@ struct SmoothquantPipelineTwoPass
using
YScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
using
YScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
//
TODO - BlockSmoothquantProblem::kP
adM
static
constexpr
bool
kPadM
=
false
;
//
No need to p
ad
M
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kSmoothX
=
Problem
::
kSmoothX
;
static
constexpr
bool
UseMax3
=
true
;
// TODO - Move to trait
static
constexpr
bool
UseMax3
=
true
;
// TODO - Move to trait
static
constexpr
const
char
*
name
=
[]()
{
static
constexpr
const
char
*
name
=
[]()
{
...
@@ -76,14 +77,23 @@ struct SmoothquantPipelineTwoPass
...
@@ -76,14 +77,23 @@ struct SmoothquantPipelineTwoPass
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
{
const
auto
x
=
load_tile
(
x_window
);
const
auto
x
=
load_tile
(
x_window
);
const
auto
xscale
=
load_tile
(
xscale_window
);
const
auto
y
=
tile_elementwise_in
(
auto
y
=
[
&
]()
{
[
&
](
const
auto
&
a
,
const
auto
&
b
)
{
if
constexpr
(
kSmoothX
)
return
type_convert
<
ComputeDataType
>
(
a
)
*
type_convert
<
ComputeDataType
>
(
b
);
{
},
const
auto
xscale
=
load_tile
(
xscale_window
);
x
,
return
tile_elementwise_in
(
xscale
);
[
&
](
const
auto
&
a
,
const
auto
&
b
)
{
return
type_convert
<
ComputeDataType
>
(
a
)
*
type_convert
<
ComputeDataType
>
(
b
);
},
x
,
xscale
);
}
else
return
cast_tile
<
ComputeDataType
>
(
x
);
}();
constexpr
auto
x_size_per_row
=
constexpr
auto
x_size_per_row
=
x
.
get_tile_distribution
().
get_ys_to_d_descriptor
().
get_lengths
().
at
(
number
<
1
>
{});
x
.
get_tile_distribution
().
get_ys_to_d_descriptor
().
get_lengths
().
at
(
number
<
1
>
{});
...
@@ -93,8 +103,10 @@ struct SmoothquantPipelineTwoPass
...
@@ -93,8 +103,10 @@ struct SmoothquantPipelineTwoPass
else
else
block_reduce2d
(
y
,
absmax
,
reduce_absmax_func
);
block_reduce2d
(
y
,
absmax
,
reduce_absmax_func
);
if
constexpr
(
kSmoothX
)
move_tile_window
(
xscale_window
,
{
Block_N
});
move_tile_window
(
x_window
,
{
0
,
Block_N
});
move_tile_window
(
x_window
,
{
0
,
Block_N
});
move_tile_window
(
xscale_window
,
{
Block_N
});
}
}
// compute absmax, cross-lane->cross-warp
// compute absmax, cross-lane->cross-warp
...
@@ -113,21 +125,32 @@ struct SmoothquantPipelineTwoPass
...
@@ -113,21 +125,32 @@ struct SmoothquantPipelineTwoPass
ck_tile
::
index_t
stride_to_right_most_window
=
ck_tile
::
index_t
stride_to_right_most_window
=
row_size
%
Block_N
==
0
?
row_size
-
Block_N
:
row_size
-
row_size
%
Block_N
;
row_size
%
Block_N
==
0
?
row_size
-
Block_N
:
row_size
-
row_size
%
Block_N
;
if
constexpr
(
kSmoothX
)
move_tile_window
(
xscale_window
,
{
-
Block_N
});
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
xscale_window
,
{
-
Block_N
});
move_tile_window
(
qy_window
,
{
0
,
stride_to_right_most_window
});
move_tile_window
(
qy_window
,
{
0
,
stride_to_right_most_window
});
// recompute y and quantize y to qy
// recompute y and quantize y to qy
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
{
const
auto
x
=
load_tile
(
x_window
);
const
auto
x
=
load_tile
(
x_window
);
const
auto
xscale
=
load_tile
(
xscale_window
);
const
auto
y
=
tile_elementwise_in
(
auto
y
=
[
&
]()
{
[
&
](
const
auto
&
a
,
const
auto
&
b
)
{
if
constexpr
(
kSmoothX
)
return
type_convert
<
ComputeDataType
>
(
a
)
*
type_convert
<
ComputeDataType
>
(
b
);
{
},
const
auto
xscale
=
load_tile
(
xscale_window
);
x
,
return
tile_elementwise_in
(
xscale
);
[
&
](
const
auto
&
a
,
const
auto
&
b
)
{
return
type_convert
<
ComputeDataType
>
(
a
)
*
type_convert
<
ComputeDataType
>
(
b
);
},
x
,
xscale
);
}
else
return
cast_tile
<
ComputeDataType
>
(
x
);
}();
auto
qy
=
make_static_distributed_tensor
<
QYDataType
>
(
y
.
get_tile_distribution
());
auto
qy
=
make_static_distributed_tensor
<
QYDataType
>
(
y
.
get_tile_distribution
());
sweep_tile
(
qy
,
[
&
](
auto
idx
)
{
sweep_tile
(
qy
,
[
&
](
auto
idx
)
{
...
@@ -137,8 +160,10 @@ struct SmoothquantPipelineTwoPass
...
@@ -137,8 +160,10 @@ struct SmoothquantPipelineTwoPass
});
});
store_tile
(
qy_window
,
qy
);
store_tile
(
qy_window
,
qy
);
if
constexpr
(
kSmoothX
)
move_tile_window
(
xscale_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
xscale_window
,
{
0
,
-
Block_N
});
move_tile_window
(
qy_window
,
{
0
,
-
Block_N
});
move_tile_window
(
qy_window
,
{
0
,
-
Block_N
});
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment