Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
b9806269
Commit
b9806269
authored
Dec 13, 2024
by
Aleksander Dudek
Browse files
Merge branch 'develop' into ck_tile_gemmkernel_reuse
parents
8385597f
4e731776
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
66 additions
and
42 deletions
+66
-42
example/ck_tile/12_smoothquant/example_smoothquant.cpp
example/ck_tile/12_smoothquant/example_smoothquant.cpp
+26
-18
example/ck_tile/12_smoothquant/smoothquant.cpp
example/ck_tile/12_smoothquant/smoothquant.cpp
+26
-18
include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp
...ude/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp
+14
-6
No files found.
example/ck_tile/12_smoothquant/example_smoothquant.cpp
View file @
b9806269
...
@@ -35,7 +35,8 @@ auto create_args(int argc, char* argv[])
...
@@ -35,7 +35,8 @@ auto create_args(int argc, char* argv[])
ck_tile
::
ArgParser
arg_parser
;
ck_tile
::
ArgParser
arg_parser
;
arg_parser
.
insert
(
"m"
,
"3328"
,
"m dimension"
)
arg_parser
.
insert
(
"m"
,
"3328"
,
"m dimension"
)
.
insert
(
"n"
,
"4096"
,
"n dimension"
)
.
insert
(
"n"
,
"4096"
,
"n dimension"
)
.
insert
(
"stride"
,
"-1"
,
"stride per row, if -1 then equal to n"
)
.
insert
(
"x_stride"
,
"-1"
,
"input stride per row, if -1 then equal to n"
)
.
insert
(
"y_stride"
,
"-1"
,
"output stride per row, if -1 then equal to n"
)
.
insert
(
"e"
,
"1e-5"
,
"epsilon"
)
.
insert
(
"e"
,
"1e-5"
,
"epsilon"
)
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
.
insert
(
"prec"
,
"fp16"
,
"precision"
)
.
insert
(
"prec"
,
"fp16"
,
"precision"
)
...
@@ -49,11 +50,14 @@ auto create_args(int argc, char* argv[])
...
@@ -49,11 +50,14 @@ auto create_args(int argc, char* argv[])
template
<
typename
DataType
>
template
<
typename
DataType
>
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
{
{
ck_tile
::
index_t
m
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
m
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
n
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
n
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
stride
=
arg_parser
.
get_int
(
"stride"
);
ck_tile
::
index_t
x_stride
=
arg_parser
.
get_int
(
"x_stride"
);
if
(
stride
<
0
)
if
(
x_stride
<
0
)
stride
=
n
;
x_stride
=
n
;
ck_tile
::
index_t
y_stride
=
arg_parser
.
get_int
(
"y_stride"
);
if
(
y_stride
<
0
)
y_stride
=
n
;
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
...
@@ -68,14 +72,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -68,14 +72,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
using
ComputeDataType
=
float
;
using
ComputeDataType
=
float
;
// host verify
// host verify
ck_tile
::
HostTensor
<
XDataType
>
x_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
XDataType
>
x_host
({
m
,
n
},
{
x_
stride
,
1
});
ck_tile
::
HostTensor
<
XScaleDataType
>
xscale_host
({
n
});
ck_tile
::
HostTensor
<
XScaleDataType
>
xscale_host
({
n
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_ref
({
m
},
{
1
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_ref
({
m
},
{
1
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_dev
({
m
},
{
1
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_dev
({
m
},
{
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_ref
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_ref
({
m
,
n
},
{
y_
stride
,
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_dev
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_dev
({
m
,
n
},
{
y_
stride
,
1
});
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
.5
f
,
.5
f
}(
x_host
);
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
.5
f
,
.5
f
}(
x_host
);
ck_tile
::
FillUniformDistribution
<
XScaleDataType
>
{
1e-3
,
.5
f
}(
xscale_host
);
ck_tile
::
FillUniformDistribution
<
XScaleDataType
>
{
1e-3
,
.5
f
}(
xscale_host
);
...
@@ -116,7 +120,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -116,7 +120,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf
.
GetDeviceBuffer
(),
qy_buf
.
GetDeviceBuffer
(),
m
,
m
,
n
,
n
,
stride
};
x_stride
,
y_stride
};
auto
kargs
=
Kernel
::
MakeKargs
(
args
);
auto
kargs
=
Kernel
::
MakeKargs
(
args
);
...
@@ -133,7 +138,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -133,7 +138,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
do_validation
)
if
(
do_validation
)
{
{
using
YDataType
=
ComputeDataType
;
using
YDataType
=
ComputeDataType
;
ck_tile
::
HostTensor
<
ComputeDataType
>
y_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
ComputeDataType
>
y_host
({
m
,
n
},
{
y_
stride
,
1
});
// smooth outlier
// smooth outlier
{
{
auto
f
=
[
&
](
auto
n_
)
{
auto
f
=
[
&
](
auto
n_
)
{
...
@@ -183,7 +188,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -183,7 +188,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf
.
FromDevice
(
qy_host_dev
.
data
());
qy_buf
.
FromDevice
(
qy_host_dev
.
data
());
auto
[
rtol
,
atol
]
=
get_elimit
<
QYDataType
>
();
auto
[
rtol
,
atol
]
=
get_elimit
<
QYDataType
>
();
if
(
stride
==
n
)
if
(
y_
stride
==
n
)
{
{
pass
=
ck_tile
::
check_err
(
qy_host_dev
,
pass
=
ck_tile
::
check_err
(
qy_host_dev
,
qy_host_ref
,
qy_host_ref
,
...
@@ -195,10 +200,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -195,10 +200,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
{
for
(
int
i_r
=
0
;
i_r
<
m
;
i_r
++
)
for
(
int
i_r
=
0
;
i_r
<
m
;
i_r
++
)
{
{
std
::
vector
<
QYDataType
>
qy_host_dev_row
(
qy_host_dev
.
begin
()
+
i_r
*
stride
,
std
::
vector
<
QYDataType
>
qy_host_dev_row
(
qy_host_dev
.
begin
()
+
i_r
*
y_stride
,
qy_host_dev
.
begin
()
+
i_r
*
stride
+
n
);
qy_host_dev
.
begin
()
+
i_r
*
y_stride
+
std
::
vector
<
QYDataType
>
qy_host_ref_row
(
qy_host_ref
.
begin
()
+
i_r
*
stride
,
n
);
qy_host_ref
.
begin
()
+
i_r
*
stride
+
n
);
std
::
vector
<
QYDataType
>
qy_host_ref_row
(
qy_host_ref
.
begin
()
+
i_r
*
y_stride
,
qy_host_ref
.
begin
()
+
i_r
*
y_stride
+
n
);
pass
&=
ck_tile
::
check_err
(
qy_host_dev_row
,
pass
&=
ck_tile
::
check_err
(
qy_host_dev_row
,
qy_host_ref_row
,
qy_host_ref_row
,
std
::
string
(
"qy["
)
+
std
::
to_string
(
i_r
)
+
std
::
string
(
"qy["
)
+
std
::
to_string
(
i_r
)
+
...
@@ -210,8 +217,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -210,8 +217,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
std
::
cout
<<
"["
<<
data_type
<<
"]"
std
::
cout
<<
"["
<<
data_type
<<
"]"
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", stride:"
<<
stride
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", x_stride:"
<<
x_stride
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
<<
std
::
endl
;
<<
", y_stride:"
<<
y_stride
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
<<
std
::
endl
;
}
}
return
pass
;
return
pass
;
...
...
example/ck_tile/12_smoothquant/smoothquant.cpp
View file @
b9806269
...
@@ -33,7 +33,8 @@ auto create_args(int argc, char* argv[])
...
@@ -33,7 +33,8 @@ auto create_args(int argc, char* argv[])
ck_tile
::
ArgParser
arg_parser
;
ck_tile
::
ArgParser
arg_parser
;
arg_parser
.
insert
(
"m"
,
"3328"
,
"m dimension"
)
arg_parser
.
insert
(
"m"
,
"3328"
,
"m dimension"
)
.
insert
(
"n"
,
"4096"
,
"n dimension"
)
.
insert
(
"n"
,
"4096"
,
"n dimension"
)
.
insert
(
"stride"
,
"-1"
,
"stride per row, if -1 then equal to n"
)
.
insert
(
"x_stride"
,
"-1"
,
"input stride per row, if -1 then equal to n"
)
.
insert
(
"y_stride"
,
"-1"
,
"output stride per row, if -1 then equal to n"
)
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
.
insert
(
"kname"
,
"1"
,
"print kernel name or not"
)
.
insert
(
"kname"
,
"1"
,
"print kernel name or not"
)
.
insert
(
"prec"
,
"fp16"
,
"precision"
)
.
insert
(
"prec"
,
"fp16"
,
"precision"
)
...
@@ -47,18 +48,21 @@ auto create_args(int argc, char* argv[])
...
@@ -47,18 +48,21 @@ auto create_args(int argc, char* argv[])
template
<
typename
DataType
>
template
<
typename
DataType
>
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
{
{
ck_tile
::
index_t
m
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
m
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
n
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
n
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
stride
=
arg_parser
.
get_int
(
"stride"
);
ck_tile
::
index_t
x_stride
=
arg_parser
.
get_int
(
"x_stride"
);
if
(
stride
<
0
)
if
(
x_stride
<
0
)
stride
=
n
;
x_stride
=
n
;
ck_tile
::
index_t
y_stride
=
arg_parser
.
get_int
(
"y_stride"
);
if
(
y_stride
<
0
)
y_stride
=
n
;
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
int
kname
=
arg_parser
.
get_int
(
"kname"
);
int
kname
=
arg_parser
.
get_int
(
"kname"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
assert
(
stride
>=
n
);
assert
(
x_
stride
>=
n
);
using
TypeConfig
=
SmoothquantTypeConfig
<
DataType
>
;
using
TypeConfig
=
SmoothquantTypeConfig
<
DataType
>
;
...
@@ -69,14 +73,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -69,14 +73,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
using
ComputeDataType
=
typename
TypeConfig
::
ComputeDataType
;
using
ComputeDataType
=
typename
TypeConfig
::
ComputeDataType
;
// host verify
// host verify
ck_tile
::
HostTensor
<
XDataType
>
x_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
XDataType
>
x_host
({
m
,
n
},
{
x_
stride
,
1
});
ck_tile
::
HostTensor
<
XScaleDataType
>
xscale_host
({
n
});
ck_tile
::
HostTensor
<
XScaleDataType
>
xscale_host
({
n
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_ref
({
m
},
{
1
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_ref
({
m
},
{
1
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_dev
({
m
},
{
1
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_dev
({
m
},
{
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_ref
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_ref
({
m
,
n
},
{
y_
stride
,
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_dev
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_dev
({
m
,
n
},
{
y_
stride
,
1
});
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
.5
f
,
.5
f
}(
x_host
);
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
.5
f
,
.5
f
}(
x_host
);
ck_tile
::
FillUniformDistribution
<
XScaleDataType
>
{
1e-3
,
.5
f
}(
xscale_host
);
ck_tile
::
FillUniformDistribution
<
XScaleDataType
>
{
1e-3
,
.5
f
}(
xscale_host
);
...
@@ -90,7 +94,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -90,7 +94,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
xscale_buf
.
ToDevice
(
xscale_host
.
data
());
xscale_buf
.
ToDevice
(
xscale_host
.
data
());
std
::
cout
<<
"["
<<
data_type
<<
"]"
std
::
cout
<<
"["
<<
data_type
<<
"]"
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", stride:"
<<
stride
<<
std
::
flush
;
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", x_stride:"
<<
x_stride
<<
", y_stride:"
<<
y_stride
<<
std
::
flush
;
smoothquant_traits
traits
{
data_type
};
smoothquant_traits
traits
{
data_type
};
...
@@ -100,7 +105,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -100,7 +105,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf
.
GetDeviceBuffer
(),
qy_buf
.
GetDeviceBuffer
(),
m
,
m
,
n
,
n
,
stride
};
x_stride
,
y_stride
};
float
ave_time
=
smoothquant
(
float
ave_time
=
smoothquant
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
...
@@ -116,7 +122,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -116,7 +122,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
do_validation
)
if
(
do_validation
)
{
{
using
YDataType
=
ComputeDataType
;
using
YDataType
=
ComputeDataType
;
ck_tile
::
HostTensor
<
ComputeDataType
>
y_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
ComputeDataType
>
y_host
({
m
,
n
},
{
y_
stride
,
1
});
// smooth outlier
// smooth outlier
{
{
auto
f
=
[
&
](
auto
n_
)
{
auto
f
=
[
&
](
auto
n_
)
{
...
@@ -166,7 +172,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -166,7 +172,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf
.
FromDevice
(
qy_host_dev
.
data
());
qy_buf
.
FromDevice
(
qy_host_dev
.
data
());
auto
[
rtol
,
atol
]
=
get_elimit
<
QYDataType
>
();
auto
[
rtol
,
atol
]
=
get_elimit
<
QYDataType
>
();
if
(
stride
==
n
)
if
(
y_
stride
==
n
)
{
{
pass
=
ck_tile
::
check_err
(
qy_host_dev
,
pass
=
ck_tile
::
check_err
(
qy_host_dev
,
qy_host_ref
,
qy_host_ref
,
...
@@ -178,10 +184,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -178,10 +184,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
{
for
(
int
i_r
=
0
;
i_r
<
m
;
i_r
++
)
for
(
int
i_r
=
0
;
i_r
<
m
;
i_r
++
)
{
{
std
::
vector
<
QYDataType
>
qy_host_dev_row
(
qy_host_dev
.
begin
()
+
i_r
*
stride
,
std
::
vector
<
QYDataType
>
qy_host_dev_row
(
qy_host_dev
.
begin
()
+
i_r
*
y_stride
,
qy_host_dev
.
begin
()
+
i_r
*
stride
+
n
);
qy_host_dev
.
begin
()
+
i_r
*
y_stride
+
std
::
vector
<
QYDataType
>
qy_host_ref_row
(
qy_host_ref
.
begin
()
+
i_r
*
stride
,
n
);
qy_host_ref
.
begin
()
+
i_r
*
stride
+
n
);
std
::
vector
<
QYDataType
>
qy_host_ref_row
(
qy_host_ref
.
begin
()
+
i_r
*
y_stride
,
qy_host_ref
.
begin
()
+
i_r
*
y_stride
+
n
);
pass
&=
ck_tile
::
check_err
(
qy_host_dev_row
,
pass
&=
ck_tile
::
check_err
(
qy_host_dev_row
,
qy_host_ref_row
,
qy_host_ref_row
,
std
::
string
(
"qy["
)
+
std
::
to_string
(
i_r
)
+
std
::
string
(
"qy["
)
+
std
::
to_string
(
i_r
)
+
...
...
include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp
View file @
b9806269
...
@@ -19,7 +19,8 @@ struct SmoothquantHostArgs
...
@@ -19,7 +19,8 @@ struct SmoothquantHostArgs
index_t
m
;
index_t
m
;
index_t
n
;
index_t
n
;
index_t
stride
;
// row_stride
index_t
x_stride
;
// input row_stride
index_t
y_stride
;
// output row_stride
};
};
// TODO: Extract some type to wrapper class
// TODO: Extract some type to wrapper class
...
@@ -58,14 +59,21 @@ struct Smoothquant
...
@@ -58,14 +59,21 @@ struct Smoothquant
index_t
m
;
index_t
m
;
index_t
n
;
index_t
n
;
index_t
stride
;
// row_stride
index_t
x_stride
;
// input row_stride
index_t
y_stride
;
// out row_stride
};
};
using
Hargs
=
SmoothquantHostArgs
;
using
Hargs
=
SmoothquantHostArgs
;
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
{
{
return
Kargs
{
return
Kargs
{
hargs
.
p_x
,
hargs
.
p_x
,
hargs
.
p_xscale
,
hargs
.
p_yscale
,
hargs
.
p_qy
,
hargs
.
m
,
hargs
.
n
,
hargs
.
stride
};
hargs
.
p_xscale
,
hargs
.
p_yscale
,
hargs
.
p_qy
,
hargs
.
m
,
hargs
.
n
,
hargs
.
x_stride
,
hargs
.
y_stride
};
}
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
...
@@ -116,7 +124,7 @@ struct Smoothquant
...
@@ -116,7 +124,7 @@ struct Smoothquant
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
),
make_tuple
(
kargs
.
x_
stride
,
1
),
number
<
Vector_N
>
{},
number
<
Vector_N
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
@@ -157,7 +165,7 @@ struct Smoothquant
...
@@ -157,7 +165,7 @@ struct Smoothquant
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
QYDataType
*>
(
kargs
.
p_qy
),
static_cast
<
QYDataType
*>
(
kargs
.
p_qy
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
),
make_tuple
(
kargs
.
y_
stride
,
1
),
number
<
Vector_N
>
{},
number
<
Vector_N
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment