Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
a7ae4f8e
Commit
a7ae4f8e
authored
Jan 27, 2025
by
Astha Rai
Browse files
Merge branch 'codegen_hiprtc' of github.com:ROCm/composable_kernel into codegen_hiprtc
parents
a6055c3c
781005a5
Changes
175
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
543 additions
and
245 deletions
+543
-245
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
+4
-4
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp
...layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp
+13
-13
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
+3
-3
include/ck_tile/ops/rmsnorm2d.hpp
include/ck_tile/ops/rmsnorm2d.hpp
+1
-0
include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp
...ude/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp
+169
-28
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp
...norm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp
+5
-5
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp
...ps/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp
+66
-15
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp
...ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp
+13
-13
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp
...ps/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp
+73
-18
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp
...e/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp
+54
-0
include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp
...ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp
+15
-13
include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp
...ude/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp
+15
-15
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_default_policy.hpp
...othquant/pipeline/smoothquant_pipeline_default_policy.hpp
+2
-2
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp
...ps/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp
+18
-15
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp
...ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp
+9
-9
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp
...ps/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp
+25
-22
library/include/ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp
...reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp
+4
-34
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
...library/reference_tensor_operation/cpu/reference_gemm.hpp
+4
-15
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp
...erence_tensor_operation/cpu/reference_gemm_multiple_d.hpp
+3
-21
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp
...y/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp
+47
-0
No files found.
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
View file @
a7ae4f8e
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -64,7 +64,7 @@ struct Layernorm2dFwdPipelineOnePass
...
@@ -64,7 +64,7 @@ struct Layernorm2dFwdPipelineOnePass
typename
YResidualWindow
,
typename
YResidualWindow
,
typename
MeanWindow
,
typename
MeanWindow
,
typename
InvStdWindow
,
typename
InvStdWindow
,
typename
X
ScaleWindow
,
typename
Smooth
ScaleWindow
,
typename
YScaleWindow
,
typename
YScaleWindow
,
typename
Epilogue
>
typename
Epilogue
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
...
@@ -76,7 +76,7 @@ struct Layernorm2dFwdPipelineOnePass
...
@@ -76,7 +76,7 @@ struct Layernorm2dFwdPipelineOnePass
const
YResidualWindow
&
y_residual_window_
,
const
YResidualWindow
&
y_residual_window_
,
MeanWindow
&
mean_window
,
MeanWindow
&
mean_window
,
InvStdWindow
&
inv_std_window
,
InvStdWindow
&
inv_std_window
,
const
X
ScaleWindow
&
x
_scale_window_
,
const
Smooth
ScaleWindow
&
sm
_scale_window_
,
YScaleWindow
&
y_scale_window
,
YScaleWindow
&
y_scale_window
,
ComputeDataType
epsilon
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
row_size
,
ck_tile
::
index_t
row_size
,
...
@@ -190,7 +190,7 @@ struct Layernorm2dFwdPipelineOnePass
...
@@ -190,7 +190,7 @@ struct Layernorm2dFwdPipelineOnePass
if
constexpr
(
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
DYNAMIC_QUANT
||
if
constexpr
(
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
DYNAMIC_QUANT
||
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
{
{
Epilogue
{}(
y_window_
,
x
_scale_window_
,
y_scale_window
,
ln
,
smem
);
Epilogue
{}(
y_window_
,
sm
_scale_window_
,
y_scale_window
,
ln
,
smem
);
}
}
else
else
Epilogue
{}(
y_window_
,
ln
);
Epilogue
{}(
y_window_
,
ln
);
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp
View file @
a7ae4f8e
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -15,23 +15,23 @@ template <typename XDataType_,
...
@@ -15,23 +15,23 @@ template <typename XDataType_,
typename
YDataType_
,
typename
YDataType_
,
typename
MeanDataType_
,
typename
MeanDataType_
,
typename
InvStdDataType_
,
typename
InvStdDataType_
,
typename
X
ScaleDataType_
,
typename
Smooth
ScaleDataType_
,
typename
YScaleDataType_
,
typename
YScaleDataType_
,
typename
BlockShape_
,
typename
BlockShape_
,
typename
Traits_
>
typename
Traits_
>
struct
Layernorm2dFwdPipelineProblem
struct
Layernorm2dFwdPipelineProblem
{
{
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
XBiasDataType
=
remove_cvref_t
<
XBiasDataType_
>
;
using
XBiasDataType
=
remove_cvref_t
<
XBiasDataType_
>
;
using
GammaDataType
=
remove_cvref_t
<
GammaDataType_
>
;
using
GammaDataType
=
remove_cvref_t
<
GammaDataType_
>
;
using
BetaDataType
=
remove_cvref_t
<
BetaDataType_
>
;
using
BetaDataType
=
remove_cvref_t
<
BetaDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
YDataType
=
remove_cvref_t
<
YDataType_
>
;
using
YDataType
=
remove_cvref_t
<
YDataType_
>
;
using
MeanDataType
=
remove_cvref_t
<
MeanDataType_
>
;
using
MeanDataType
=
remove_cvref_t
<
MeanDataType_
>
;
using
InvStdDataType
=
remove_cvref_t
<
InvStdDataType_
>
;
using
InvStdDataType
=
remove_cvref_t
<
InvStdDataType_
>
;
using
X
ScaleDataType
=
remove_cvref_t
<
X
ScaleDataType_
>
;
using
Smooth
ScaleDataType
=
remove_cvref_t
<
Smooth
ScaleDataType_
>
;
using
YScaleDataType
=
remove_cvref_t
<
YScaleDataType_
>
;
using
YScaleDataType
=
remove_cvref_t
<
YScaleDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
bool
kNeedCrossLaneSync
=
BlockShape
::
ThreadPerWarp_N
>
1
;
static
constexpr
bool
kNeedCrossLaneSync
=
BlockShape
::
ThreadPerWarp_N
>
1
;
static
constexpr
bool
kNeedCrossWarpSync
=
BlockShape
::
WarpPerBlock_N
>
1
;
static
constexpr
bool
kNeedCrossWarpSync
=
BlockShape
::
WarpPerBlock_N
>
1
;
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
View file @
a7ae4f8e
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -63,7 +63,7 @@ struct Layernorm2dFwdPipelineTwoPass
...
@@ -63,7 +63,7 @@ struct Layernorm2dFwdPipelineTwoPass
typename
YResidualWindow
,
typename
YResidualWindow
,
typename
MeanWindow
,
typename
MeanWindow
,
typename
InvStdWindow
,
typename
InvStdWindow
,
typename
X
ScaleWindow
,
typename
Smooth
ScaleWindow
,
typename
YScaleWindow
,
typename
YScaleWindow
,
typename
Epilogue
>
typename
Epilogue
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
...
@@ -75,7 +75,7 @@ struct Layernorm2dFwdPipelineTwoPass
...
@@ -75,7 +75,7 @@ struct Layernorm2dFwdPipelineTwoPass
const
YResidualWindow
&
y_residual_window_
,
const
YResidualWindow
&
y_residual_window_
,
MeanWindow
&
mean_window
,
MeanWindow
&
mean_window
,
InvStdWindow
&
inv_std_window
,
InvStdWindow
&
inv_std_window
,
const
X
ScaleWindow
&
/*
x
_scale_window*/
,
const
Smooth
ScaleWindow
&
/*
sm
_scale_window*/
,
YScaleWindow
&
/*y_scale_window*/
,
YScaleWindow
&
/*y_scale_window*/
,
ComputeDataType
epsilon
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
row_size
,
ck_tile
::
index_t
row_size
,
...
...
include/ck_tile/ops/rmsnorm2d.hpp
View file @
a7ae4f8e
...
@@ -8,5 +8,6 @@
...
@@ -8,5 +8,6 @@
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp
View file @
a7ae4f8e
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
// host side args
// host side args
struct
Rmsnorm2dFwdHostArgs
struct
Rmsnorm2dFwdHostArgs
{
{
const
void
*
p_x
;
// [m ,n], input, fp16/bf16
const
void
*
p_x
;
// [m ,n], input, fp16/bf16
const
void
*
p_gamma
;
// [1, n], gamma, prec same as input
const
void
*
p_x_residual
;
// [m ,n], shortcut input, prec same as input, nullptr if not used
const
void
*
p_sm_scale
;
// [1 ,n], smooth scale input, fp32, nullptr if not used
const
void
*
p_gamma
;
// [1, n], gamma, prec same as input
void
*
p_y
;
// [m, n], output, fp16/bf16
void
*
p_y
;
// [m, n], output, fp16/bf16
void
*
p_invRms
;
// [m, 1], output inv-rms, prec same as input, nullptr if not used
void
*
p_y_residual
;
// [m, n], shortcut output, prec same as input, nullptr if not used
void
*
p_y_scale
;
// [m, 1], output a dynamic quant per row, nullptr if not used
void
*
p_invRms
;
// [m, 1], output inv-rms, prec same as input, nullptr if not used
float
epsilon
;
float
epsilon
;
index_t
m
;
index_t
m
;
index_t
n
;
index_t
n
;
index_t
stride
;
// row_stride
index_t
x_stride
;
// x row_stride
index_t
xr_stride
;
// x residule row stride
index_t
y_stride
;
// y row stride
index_t
yr_stride
;
// y residule row stride
};
};
// TODO: Extract some type to wrapper class
// TODO: Extract some type to wrapper class
template
<
typename
Pipeline_
>
template
<
typename
Pipeline_
,
typename
Epilogue_
>
struct
Rmsnorm2dFwd
struct
Rmsnorm2dFwd
{
{
using
Pipeline
=
remove_cvref_t
<
Pipeline_
>
;
using
Pipeline
=
remove_cvref_t
<
Pipeline_
>
;
using
Epilogue
=
remove_cvref_t
<
Epilogue_
>
;
using
Problem
=
typename
Pipeline
::
Problem
;
using
Problem
=
typename
Pipeline
::
Problem
;
using
XDataType
=
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
XDataType
=
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
GammaDataType
=
remove_cvref_t
<
typename
Problem
::
GammaDataType
>
;
using
GammaDataType
=
remove_cvref_t
<
typename
Problem
::
GammaDataType
>
;
using
ComputeDataType
=
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
ComputeDataType
=
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
YDataType
=
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
YDataType
=
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
InvRmsDataType
=
remove_cvref_t
<
typename
Problem
::
InvRmsDataType
>
;
using
InvRmsDataType
=
remove_cvref_t
<
typename
Problem
::
InvRmsDataType
>
;
using
SmoothScaleDataType
=
remove_cvref_t
<
typename
Problem
::
SmoothScaleDataType
>
;
using
YScaleDataType
=
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
// for simplicity, shortcut input/output type is same as X
using
XResidualDataType
=
XDataType
;
using
YResidualDataType
=
XDataType
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
null_type
>
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
null_type
>
;
static
constexpr
bool
kSaveInvRms
=
Problem
::
kSaveInvRms
;
static
constexpr
bool
kSaveInvRms
=
Problem
::
Traits
::
kSaveInvRms
;
static
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M
;
static
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M
;
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
static
constexpr
bool
kPadM
=
false
;
// always no need to pad along M
static
constexpr
bool
kPadM
=
false
;
// always no need to pad along M
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
bool
kTwoPass
=
Problem
::
kTwoPass
;
static
constexpr
bool
kTwoPass
=
Problem
::
Traits
::
kTwoPass
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
static
constexpr
index_t
ThreadPerWarp_N
=
Problem
::
BlockShape
::
ThreadPerWarp_N
;
static
constexpr
index_t
ThreadPerWarp_N
=
Problem
::
BlockShape
::
ThreadPerWarp_N
;
static
constexpr
index_t
Vector_N
=
Problem
::
BlockShape
::
Vector_N
;
static
constexpr
index_t
Vector_N
=
Problem
::
BlockShape
::
Vector_N
;
...
@@ -56,29 +73,43 @@ struct Rmsnorm2dFwd
...
@@ -56,29 +73,43 @@ struct Rmsnorm2dFwd
struct
Kargs
struct
Kargs
{
{
const
void
*
p_x
;
const
void
*
p_x
;
const
void
*
p_x_residual
;
const
void
*
p_sm_scale
;
const
void
*
p_gamma
;
const
void
*
p_gamma
;
void
*
p_y
;
void
*
p_y
;
void
*
p_y_residual
;
void
*
p_y_scale
;
void
*
p_invRms
;
void
*
p_invRms
;
float
epsilon
;
float
epsilon
;
index_t
m
;
index_t
m
;
index_t
n
;
index_t
n
;
index_t
stride
;
// row_stride
index_t
x_stride
;
// x row_stride
index_t
xr_stride
;
// x residule row stride
index_t
y_stride
;
// y row stride
index_t
yr_stride
;
// y residule row stride
};
};
using
Hargs
=
Rmsnorm2dFwdHostArgs
;
using
Hargs
=
Rmsnorm2dFwdHostArgs
;
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
{
{
return
Kargs
{
hargs
.
p_x
,
return
Kargs
{
hargs
.
p_x
,
hargs
.
p_x_residual
,
hargs
.
p_sm_scale
,
hargs
.
p_gamma
,
hargs
.
p_gamma
,
hargs
.
p_y
,
hargs
.
p_y
,
hargs
.
p_y_residual
,
hargs
.
p_y_scale
,
hargs
.
p_invRms
,
hargs
.
p_invRms
,
hargs
.
epsilon
,
hargs
.
epsilon
,
hargs
.
m
,
hargs
.
m
,
hargs
.
n
,
hargs
.
n
,
hargs
.
stride
};
hargs
.
x_stride
,
hargs
.
xr_stride
,
hargs
.
y_stride
,
hargs
.
yr_stride
};
}
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
...
@@ -95,6 +126,7 @@ struct Rmsnorm2dFwd
...
@@ -95,6 +126,7 @@ struct Rmsnorm2dFwd
template
<
>
struct
t2s
<
ck_tile
::
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
template
<
>
struct
t2s
<
ck_tile
::
fp8_t
>
{
static
constexpr
const
char
*
name
=
"fp8"
;
};
template
<
>
struct
t2s
<
ck_tile
::
fp8_t
>
{
static
constexpr
const
char
*
name
=
"fp8"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
template
<
>
struct
t2s
<
ck_tile
::
int8_t
>
{
static
constexpr
const
char
*
name
=
"int8"
;
};
// clang-format on
// clang-format on
// in byte
// in byte
...
@@ -102,24 +134,41 @@ struct Rmsnorm2dFwd
...
@@ -102,24 +134,41 @@ struct Rmsnorm2dFwd
CK_TILE_HOST
static
std
::
string
GetName
()
CK_TILE_HOST
static
std
::
string
GetName
()
{
{
#define _SS_ std::string
#define _TS_ std::to_string
// clang-format off
// clang-format off
using
S_
=
typename
Problem
::
BlockShape
;
using
S_
=
typename
Problem
::
BlockShape
;
auto
surfix
=
[
&
]
()
{
auto
surfix
=
[
&
]
()
{
std
::
string
n
;
std
::
string
n
;
if
(
kFusedAdd
!=
Rmsnorm2dFusedAddEnum
::
NO_ADD
)
n
+=
_SS_
(
"_"
)
+
Rmsnorm2dFusedAddEnumName
<
kFusedAdd
>::
name
;
if
(
kFusedQuant
!=
Rmsnorm2dFusedQuantEnum
::
NO_SWEEP
)
n
+=
_SS_
(
"_"
)
+
Rmsnorm2dFusedQuantEnumName
<
kFusedQuant
>::
name
;
if
(
kPadN
)
n
+=
"_pn"
;
if
(
kPadN
)
n
+=
"_pn"
;
if
(
kSaveInvRms
)
n
+=
"_rms"
;
if
(
kSaveInvRms
)
n
+=
"_rms"
;
if
(
kTwoPass
)
n
+=
"_2p"
;
if
(
kTwoPass
)
n
+=
"_2p"
;
return
n
;
}();
return
n
;
}();
#define _SS_ std::string
auto
prec_str
=
[
&
]
()
{
#define _TS_ std::to_string
std
::
string
base_str
=
_SS_
(
t2s
<
XDataType
>::
name
);
return
_SS_
(
"rmsnorm2d_fwd_"
)
+
_SS_
(
t2s
<
XDataType
>::
name
)
+
"_"
+
if
(
!
std
::
is_same_v
<
XDataType
,
YDataType
>
)
{
base_str
+=
_SS_
(
"_"
)
+
_SS_
(
t2s
<
YDataType
>::
name
);
}
if
(
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
{
base_str
+=
_SS_
(
"_sx"
)
+
_SS_
(
t2s
<
SmoothScaleDataType
>::
name
);
base_str
+=
_SS_
(
"_sy"
)
+
_SS_
(
t2s
<
YScaleDataType
>::
name
);
}
if
(
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
DYNAMIC_QUANT
)
{
base_str
+=
_SS_
(
"_sy"
)
+
_SS_
(
t2s
<
YScaleDataType
>::
name
);
}
return
base_str
;
}();
return
_SS_
(
"rmsnorm2d_fwd_"
)
+
_SS_
(
prec_str
)
+
"_"
+
_TS_
(
S_
::
Block_M
)
+
"x"
+
_TS_
(
S_
::
Block_N
)
+
"_"
+
_TS_
(
S_
::
WarpPerBlock_M
)
+
"x"
+
_TS_
(
S_
::
WarpPerBlock_N
)
+
"_"
+
_TS_
(
S_
::
Block_M
)
+
"x"
+
_TS_
(
S_
::
Block_N
)
+
"_"
+
_TS_
(
S_
::
WarpPerBlock_M
)
+
"x"
+
_TS_
(
S_
::
WarpPerBlock_N
)
+
"_"
+
_TS_
(
S_
::
Warp_M
)
+
"x"
+
_TS_
(
S_
::
Warp_N
)
+
"_"
+
_TS_
(
S_
::
Vector_M
)
+
"x"
+
_TS_
(
S_
::
Vector_N
)
+
"_"
+
_TS_
(
S_
::
Warp_M
)
+
"x"
+
_TS_
(
S_
::
Warp_N
)
+
"_"
+
_TS_
(
S_
::
Vector_M
)
+
"x"
+
_TS_
(
S_
::
Vector_N
)
+
"_"
+
_SS_
(
Pipeline
::
name
)
+
surfix
;
_SS_
(
Pipeline
::
name
)
+
surfix
;
#undef _SS_
#undef _TS_
// clang-format on
// clang-format on
#undef _SS_
#undef _TS_
}
}
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
...
@@ -130,7 +179,7 @@ struct Rmsnorm2dFwd
...
@@ -130,7 +179,7 @@ struct Rmsnorm2dFwd
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
),
make_tuple
(
kargs
.
x_
stride
,
1
),
number
<
Vector_N
>
{},
number
<
Vector_N
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
@@ -140,6 +189,29 @@ struct Rmsnorm2dFwd
...
@@ -140,6 +189,29 @@ struct Rmsnorm2dFwd
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
}();
}();
const
auto
x_residual_window
=
[
&
]()
{
if
constexpr
(
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD
||
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD_STORE
)
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
XResidualDataType
*>
(
kargs
.
p_x_residual
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
xr_stride
,
1
),
number
<
Vector_N
>
{},
number
<
1
>
{});
const
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
sequence
<
kPadM
,
kPadN
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
}
else
{
return
make_null_tile_window
(
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}));
}
}();
const
auto
gamma_window
=
[
&
]()
{
const
auto
gamma_window
=
[
&
]()
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
GammaDataType
*>
(
kargs
.
p_gamma
),
static_cast
<
const
GammaDataType
*>
(
kargs
.
p_gamma
),
...
@@ -158,7 +230,7 @@ struct Rmsnorm2dFwd
...
@@ -158,7 +230,7 @@ struct Rmsnorm2dFwd
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
YDataType
*>
(
kargs
.
p_y
),
static_cast
<
YDataType
*>
(
kargs
.
p_y
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
),
make_tuple
(
kargs
.
y_
stride
,
1
),
number
<
Vector_N
>
{},
number
<
Vector_N
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
@@ -168,6 +240,28 @@ struct Rmsnorm2dFwd
...
@@ -168,6 +240,28 @@ struct Rmsnorm2dFwd
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
}();
}();
auto
y_residual_window
=
[
&
]()
{
if
constexpr
(
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD_STORE
)
{
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
YResidualDataType
*>
(
kargs
.
p_y_residual
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
yr_stride
,
1
),
number
<
Vector_N
>
{},
number
<
1
>
{});
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
sequence
<
kPadM
,
kPadN
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
}
else
{
return
make_null_tile_window
(
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}));
}
}();
auto
inv_rms_window
=
[
&
]()
{
auto
inv_rms_window
=
[
&
]()
{
if
constexpr
(
kSaveInvRms
)
if
constexpr
(
kSaveInvRms
)
{
{
...
@@ -187,15 +281,62 @@ struct Rmsnorm2dFwd
...
@@ -187,15 +281,62 @@ struct Rmsnorm2dFwd
return
make_null_tile_window
(
make_tuple
(
number
<
Block_M
>
{}));
return
make_null_tile_window
(
make_tuple
(
number
<
Block_M
>
{}));
}();
}();
auto
sm_scale_window
=
[
&
]()
{
if
constexpr
(
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
{
const
auto
win_
=
[
&
]()
{
const
auto
tmp_0_
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
static_cast
<
const
SmoothScaleDataType
*>
(
kargs
.
p_sm_scale
),
make_tuple
(
kargs
.
n
),
number
<
Vector_N
>
{});
return
pad_tensor_view
(
tmp_0_
,
make_tuple
(
number
<
Block_N
>
{}),
sequence
<
false
>
{});
// sm_scale no need pad
}();
return
make_tile_window
(
win_
,
make_tuple
(
number
<
Block_N
>
{}),
{
0
});
}
else
{
return
make_null_tile_window
(
make_tuple
(
number
<
Block_N
>
{}));
}
}();
auto
y_scale_window
=
[
&
]()
{
if
constexpr
(
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
||
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
DYNAMIC_QUANT
)
{
const
auto
win_
=
[
&
]()
{
const
auto
tmp_0_
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
static_cast
<
YScaleDataType
*>
(
kargs
.
p_y_scale
),
make_tuple
(
kargs
.
m
),
number
<
1
>
{});
return
pad_tensor_view
(
tmp_0_
,
make_tuple
(
number
<
Block_M
>
{}),
sequence
<
kPadM
>
{});
}();
return
make_tile_window
(
win_
,
make_tuple
(
number
<
Block_M
>
{}),
{
iM
});
}
else
{
return
make_null_tile_window
(
make_tuple
(
number
<
Block_M
>
{}));
}
}();
__shared__
char
smem
[
GetSmemSize
()];
__shared__
char
smem
[
GetSmemSize
()];
Pipeline
{}(
x_window
,
Pipeline
{}(
x_window
,
x_residual_window
,
gamma_window
,
gamma_window
,
y_window
,
y_window
,
y_residual_window
,
inv_rms_window
,
inv_rms_window
,
sm_scale_window
,
y_scale_window
,
static_cast
<
const
ComputeDataType
>
(
kargs
.
epsilon
),
static_cast
<
const
ComputeDataType
>
(
kargs
.
epsilon
),
kargs
.
n
,
kargs
.
n
,
smem
);
smem
,
Epilogue
{});
}
}
};
};
...
...
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp
View file @
a7ae4f8e
...
@@ -45,7 +45,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
...
@@ -45,7 +45,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2d
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2d
()
{
{
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
X
DataType
,
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
Compute
DataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
typename
Problem
::
BlockShape
>
;
return
BlockReduce2d
<
P_
>
{};
return
BlockReduce2d
<
P_
>
{};
...
@@ -54,7 +54,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
...
@@ -54,7 +54,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2dSync
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2dSync
()
{
{
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
X
DataType
,
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
Compute
DataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
typename
Problem
::
BlockShape
>
;
return
BlockReduce2dSync
<
P_
>
{};
return
BlockReduce2dSync
<
P_
>
{};
...
@@ -63,7 +63,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
...
@@ -63,7 +63,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2dCrossWarpSync
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2dCrossWarpSync
()
{
{
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
X
DataType
,
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
Compute
DataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
typename
Problem
::
BlockShape
>
;
return
BlockReduce2dCrossWarpSync
<
P_
>
{};
return
BlockReduce2dCrossWarpSync
<
P_
>
{};
...
@@ -74,13 +74,13 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
...
@@ -74,13 +74,13 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
{
{
if
constexpr
(
Problem
::
kNeedCrossWarpSync
)
if
constexpr
(
Problem
::
kNeedCrossWarpSync
)
{
{
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
X
DataType
,
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
Compute
DataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
typename
Problem
::
BlockShape
>
;
using
block_reduce2d
=
BlockReduce2d
<
P_
>
;
using
block_reduce2d
=
BlockReduce2d
<
P_
>
;
using
x_block_tile
=
using
x_block_tile
=
decltype
(
make_static_distributed_tensor
<
typename
Problem
::
X
DataType
>
(
decltype
(
make_static_distributed_tensor
<
typename
Problem
::
Compute
DataType
>
(
MakeXBlockTileDistribution
<
Problem
>
()));
MakeXBlockTileDistribution
<
Problem
>
()));
using
y_block_tile
=
decltype
(
block_reduce2d
::
template
MakeYBlockTile
<
x_block_tile
>());
using
y_block_tile
=
decltype
(
block_reduce2d
::
template
MakeYBlockTile
<
x_block_tile
>());
...
...
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp
View file @
a7ae4f8e
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -22,12 +22,17 @@ struct Rmsnorm2dFwdPipelineOnePass
...
@@ -22,12 +22,17 @@ struct Rmsnorm2dFwdPipelineOnePass
using
YDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
YDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
InvRmsDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
InvRmsDataType
>
;
using
InvRmsDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
InvRmsDataType
>
;
using
XResidualDataType
=
XDataType
;
using
YResidualDataType
=
XDataType
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kSaveInvRms
=
Problem
::
kSaveInvRms
;
static
constexpr
bool
kSaveInvRms
=
Problem
::
Traits
::
kSaveInvRms
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockRmsnorm2dFwdProblem::kPadM
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockRmsnorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
static
constexpr
const
char
*
name
=
[]()
{
static
constexpr
const
char
*
name
=
[]()
{
if
constexpr
(
kNeedCrossWarpSync
)
if
constexpr
(
kNeedCrossWarpSync
)
...
@@ -41,19 +46,36 @@ struct Rmsnorm2dFwdPipelineOnePass
...
@@ -41,19 +46,36 @@ struct Rmsnorm2dFwdPipelineOnePass
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
}
template
<
typename
XWindow
,
typename
GammaWindow
,
typename
YWindow
,
typename
InvRmsWindow
>
template
<
typename
XWindow
,
typename
XResidualWindow
,
typename
GammaWindow
,
typename
YWindow
,
typename
YResidualWindow
,
typename
InvRmsWindow
,
typename
SmoothScaleWindow
,
typename
YScaleWindow
,
typename
Epilogue
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
const
XResidualWindow
&
x_residual_window_
,
const
GammaWindow
&
gamma_window_
,
const
GammaWindow
&
gamma_window_
,
YWindow
&
y_window
,
YWindow
&
y_window_
,
const
YResidualWindow
&
y_residual_window_
,
InvRmsWindow
&
inv_rms_window
,
InvRmsWindow
&
inv_rms_window
,
const
SmoothScaleWindow
&
sm_scale_window_
,
YScaleWindow
&
y_scale_window_
,
ComputeDataType
epsilon
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
row_size
,
ck_tile
::
index_t
row_size
,
void
*
smem
)
const
void
*
smem
,
Epilogue
)
const
{
{
const
auto
x_window
=
const
auto
x_window
=
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
const
auto
gamma_window
=
make_tile_window
(
const
auto
gamma_window
=
make_tile_window
(
gamma_window_
,
Policy
::
template
MakeGammaBlockTileDistribution
<
Problem
>());
gamma_window_
,
Policy
::
template
MakeGammaBlockTileDistribution
<
Problem
>());
const
auto
x_residual_window
=
make_tile_window
(
x_residual_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
y_residual_window
=
make_tile_window
(
y_residual_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
reduce_square_sum_func
=
ReduceOp
::
SquareAdd
{};
auto
reduce_square_sum_func
=
ReduceOp
::
SquareAdd
{};
auto
reduce_sum_func
=
ReduceOp
::
Add
{};
auto
reduce_sum_func
=
ReduceOp
::
Add
{};
...
@@ -62,13 +84,31 @@ struct Rmsnorm2dFwdPipelineOnePass
...
@@ -62,13 +84,31 @@ struct Rmsnorm2dFwdPipelineOnePass
auto
block_reduce2d_cross_warp_sync
=
auto
block_reduce2d_cross_warp_sync
=
Policy
::
template
GetBlockReduce2dCrossWarpSync
<
Problem
>();
Policy
::
template
GetBlockReduce2dCrossWarpSync
<
Problem
>();
const
auto
x
=
load_tile
(
x_window
);
auto
x
=
load_tile
(
x_window
);
auto
x_resi
=
load_tile
(
x_residual_window
);
// load gamma (TODO: support no gamma?)
// load gamma (TODO: support no gamma?)
const
auto
gamma
=
load_tile
(
gamma_window
);
const
auto
gamma
=
load_tile
(
gamma_window
);
auto
acc
=
cast_tile
<
ComputeDataType
>
(
x
);
if
constexpr
(
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD
||
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD_STORE
)
{
sweep_tile
(
x_resi
,
[
&
](
auto
idx
)
{
// compute x = x_resi + x
acc
(
idx
)
=
type_convert
<
ComputeDataType
>
(
x_resi
(
idx
))
+
acc
(
idx
);
});
if
constexpr
(
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD_STORE
)
{
store_tile
(
y_residual_window
,
cast_tile
<
YResidualDataType
>
(
acc
));
}
}
// compute mean square each-thread->cross-lane->cross-warp
// compute mean square each-thread->cross-lane->cross-warp
auto
square_sum
=
block_reduce2d
(
auto
square_sum
=
block_reduce2d
(
acc
,
x
,
reduce_square_sum_func
.
GetIdentityValue
<
ComputeDataType
>
(),
reduce_square_sum_func
);
reduce_square_sum_func
.
GetIdentityValue
<
ComputeDataType
>
(),
reduce_square_sum_func
);
block_reduce2d_sync
(
square_sum
,
reduce_sum_func
);
block_reduce2d_sync
(
square_sum
,
reduce_sum_func
);
block_reduce2d_cross_warp_sync
(
square_sum
,
smem
,
reduce_sum_func
);
block_reduce2d_cross_warp_sync
(
square_sum
,
smem
,
reduce_sum_func
);
...
@@ -83,19 +123,30 @@ struct Rmsnorm2dFwdPipelineOnePass
...
@@ -83,19 +123,30 @@ struct Rmsnorm2dFwdPipelineOnePass
store_tile
(
inv_rms_window
,
cast_tile
<
InvRmsDataType
>
(
inv_rms
));
store_tile
(
inv_rms_window
,
cast_tile
<
InvRmsDataType
>
(
inv_rms
));
// rmsnorm computation
// rmsnorm computation
auto
y
=
make_static_distributed_tensor
<
Y
DataType
>
(
x
.
get_tile_distribution
());
auto
rmsn
=
make_static_distributed_tensor
<
Compute
DataType
>
(
x
.
get_tile_distribution
());
sweep_tile
(
y
,
[
&
,
inv_rms_
=
inv_rms
](
auto
idx
)
{
sweep_tile
(
rmsn
,
[
&
,
inv_rms_
=
inv_rms
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
const
auto
gamma_
=
type_convert
<
ComputeDataType
>
(
gamma
[
j_idx
]);
const
auto
gamma_
=
type_convert
<
ComputeDataType
>
(
gamma
[
j_idx
]);
const
auto
x_
=
type_convert
<
ComputeDataType
>
(
x
[
idx
]);
auto
rmsn_
=
acc
[
idx
]
*
inv_rms_
[
i_idx
]
*
gamma_
;
auto
y_
=
x_
*
inv_rms_
[
i_idx
]
*
gamma_
;
y
(
idx
)
=
type_convert
<
YDataType
>
(
y_
)
;
rmsn
(
idx
)
=
rmsn_
;
});
});
store_tile
(
y_window
,
y
);
if
constexpr
(
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
{
Epilogue
{}(
y_window_
,
sm_scale_window_
,
y_scale_window_
,
rmsn
,
smem
);
}
else
if
constexpr
(
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
DYNAMIC_QUANT
)
{
Epilogue
{}(
y_window_
,
y_scale_window_
,
rmsn
,
smem
);
}
else
{
Epilogue
{}(
y_window_
,
rmsn
);
}
}
}
};
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp
View file @
a7ae4f8e
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -12,25 +12,25 @@ template <typename XDataType_,
...
@@ -12,25 +12,25 @@ template <typename XDataType_,
typename
ComputeDataType_
,
typename
ComputeDataType_
,
typename
YDataType_
,
typename
YDataType_
,
typename
InvRmsDataType_
,
typename
InvRmsDataType_
,
typename
SmoothScaleDataType_
,
typename
YScaleDataType_
,
typename
BlockShape_
,
typename
BlockShape_
,
bool
kPadN_
,
typename
Traits_
>
bool
kSaveInvRms_
,
bool
kTwoPass_
>
struct
Rmsnorm2dFwdPipelineProblem
struct
Rmsnorm2dFwdPipelineProblem
{
{
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
GammaDataType
=
remove_cvref_t
<
GammaDataType_
>
;
using
GammaDataType
=
remove_cvref_t
<
GammaDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
YDataType
=
remove_cvref_t
<
YDataType_
>
;
using
YDataType
=
remove_cvref_t
<
YDataType_
>
;
using
InvRmsDataType
=
remove_cvref_t
<
InvRmsDataType_
>
;
using
InvRmsDataType
=
remove_cvref_t
<
InvRmsDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
using
SmoothScaleDataType
=
remove_cvref_t
<
SmoothScaleDataType_
>
;
using
YScaleDataType
=
remove_cvref_t
<
YScaleDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
bool
kNeedCrossLaneSync
=
BlockShape
::
ThreadPerWarp_N
>
1
;
static
constexpr
bool
kNeedCrossLaneSync
=
BlockShape
::
ThreadPerWarp_N
>
1
;
static
constexpr
bool
kNeedCrossWarpSync
=
BlockShape
::
WarpPerBlock_N
>
1
;
static
constexpr
bool
kNeedCrossWarpSync
=
BlockShape
::
WarpPerBlock_N
>
1
;
static
constexpr
bool
kPadN
=
kPadN_
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
static
constexpr
bool
kSaveInvRms
=
kSaveInvRms_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
};
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp
View file @
a7ae4f8e
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -22,12 +22,17 @@ struct Rmsnorm2dFwdPipelineTwoPass
...
@@ -22,12 +22,17 @@ struct Rmsnorm2dFwdPipelineTwoPass
using
YDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
YDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
InvRmsDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
InvRmsDataType
>
;
using
InvRmsDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
InvRmsDataType
>
;
using
XResidualDataType
=
XDataType
;
using
YResidualDataType
=
XDataType
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kSaveInvRms
=
Problem
::
kSaveInvRms
;
static
constexpr
bool
kSaveInvRms
=
Problem
::
Traits
::
kSaveInvRms
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockRmsnorm2dFwdProblem::kPadM
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockRmsnorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
static
constexpr
const
char
*
name
=
[]()
{
static
constexpr
const
char
*
name
=
[]()
{
if
constexpr
(
kNeedCrossWarpSync
)
if
constexpr
(
kNeedCrossWarpSync
)
...
@@ -41,19 +46,36 @@ struct Rmsnorm2dFwdPipelineTwoPass
...
@@ -41,19 +46,36 @@ struct Rmsnorm2dFwdPipelineTwoPass
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
}
template
<
typename
XWindow
,
typename
GammaWindow
,
typename
YWindow
,
typename
InvRmsWindow
>
template
<
typename
XWindow
,
typename
XResidualWindow
,
typename
GammaWindow
,
typename
YWindow
,
typename
YResidualWindow
,
typename
InvRmsWindow
,
typename
SmoothScaleWindow
,
typename
YScaleWindow
,
typename
Epilogue
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
const
XResidualWindow
&
x_residual_window_
,
const
GammaWindow
&
gamma_window_
,
const
GammaWindow
&
gamma_window_
,
YWindow
&
y_window
,
YWindow
&
y_window
,
const
YResidualWindow
&
y_residual_window_
,
InvRmsWindow
&
inv_rms_window
,
InvRmsWindow
&
inv_rms_window
,
const
SmoothScaleWindow
&
/*sm_scale_window_*/
,
YScaleWindow
&
/*y_scale_window*/
,
ComputeDataType
epsilon
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
row_size
,
ck_tile
::
index_t
row_size
,
void
*
smem
)
const
void
*
smem
,
Epilogue
)
const
{
{
auto
x_window
=
auto
x_window
=
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
gamma_window
=
make_tile_window
(
auto
gamma_window
=
make_tile_window
(
gamma_window_
,
Policy
::
template
MakeGammaBlockTileDistribution
<
Problem
>());
gamma_window_
,
Policy
::
template
MakeGammaBlockTileDistribution
<
Problem
>());
auto
x_residual_window
=
make_tile_window
(
x_residual_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
y_residual_window
=
make_tile_window
(
y_residual_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
// Problem::BlockShape
// Problem::BlockShape
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
...
@@ -67,15 +89,34 @@ struct Rmsnorm2dFwdPipelineTwoPass
...
@@ -67,15 +89,34 @@ struct Rmsnorm2dFwdPipelineTwoPass
auto
block_reduce2d_cross_warp_sync
=
auto
block_reduce2d_cross_warp_sync
=
Policy
::
template
GetBlockReduce2dCrossWarpSync
<
Problem
>();
Policy
::
template
GetBlockReduce2dCrossWarpSync
<
Problem
>();
using
X
TensorType
=
decltype
(
load_tile
(
x_window
));
using
Compute
TensorType
=
decltype
(
cast_tile
<
ComputeDataType
>
(
load_tile
(
x_window
))
)
;
auto
square_sum
=
block_reduce2d
.
template
MakeYBlockTile
<
X
TensorType
>();
auto
square_sum
=
block_reduce2d
.
template
MakeYBlockTile
<
Compute
TensorType
>();
set_tile
(
square_sum
,
reduce_square_sum_func
.
GetIdentityValue
<
ComputeDataType
>
());
set_tile
(
square_sum
,
reduce_square_sum_func
.
GetIdentityValue
<
ComputeDataType
>
());
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
{
const
auto
x
=
load_tile
(
x_window
);
auto
x
=
load_tile
(
x_window
);
block_reduce2d
(
x
,
square_sum
,
reduce_square_sum_func
);
auto
x_resi
=
load_tile
(
x_residual_window
);
move_tile_window
(
x_window
,
{
0
,
Block_N
});
move_tile_window
(
x_window
,
{
0
,
Block_N
});
move_tile_window
(
x_residual_window
,
{
0
,
Block_N
});
auto
acc
=
cast_tile
<
ComputeDataType
>
(
x
);
if
constexpr
(
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD
||
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD_STORE
)
{
sweep_tile
(
x_resi
,
[
&
](
auto
idx
)
{
// compute x = x_resi + x
acc
(
idx
)
=
type_convert
<
ComputeDataType
>
(
x_resi
(
idx
))
+
acc
(
idx
);
});
if
constexpr
(
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD_STORE
)
{
store_tile
(
y_residual_window
,
cast_tile
<
YResidualDataType
>
(
acc
));
move_tile_window
(
y_residual_window
,
{
0
,
Block_N
});
}
}
block_reduce2d
(
acc
,
square_sum
,
reduce_square_sum_func
);
}
}
block_reduce2d_sync
(
square_sum
,
reduce_sum_func
);
block_reduce2d_sync
(
square_sum
,
reduce_sum_func
);
...
@@ -96,33 +137,47 @@ struct Rmsnorm2dFwdPipelineTwoPass
...
@@ -96,33 +137,47 @@ struct Rmsnorm2dFwdPipelineTwoPass
row_size
%
Block_N
==
0
?
row_size
-
Block_N
:
row_size
-
row_size
%
Block_N
;
row_size
%
Block_N
==
0
?
row_size
-
Block_N
:
row_size
-
row_size
%
Block_N
;
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_residual_window
,
{
0
,
-
Block_N
});
move_tile_window
(
gamma_window
,
{
stride_to_right_most_window
});
move_tile_window
(
gamma_window
,
{
stride_to_right_most_window
});
move_tile_window
(
y_window
,
{
0
,
stride_to_right_most_window
});
move_tile_window
(
y_window
,
{
0
,
stride_to_right_most_window
});
// rmsnorm computation
// rmsnorm computation
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
{
const
auto
x
=
load_tile
(
x_window
);
auto
x
=
load_tile
(
x_window
);
// load gamma/beta (TODO: support no gamma/beta?)
auto
x_resi
=
load_tile
(
x_residual_window
);
auto
acc
=
cast_tile
<
ComputeDataType
>
(
x
);
if
constexpr
(
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD_STORE
||
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD
)
{
sweep_tile
(
x_resi
,
[
&
](
auto
idx
)
{
// compute x = x_resi + x
acc
(
idx
)
=
type_convert
<
ComputeDataType
>
(
x_resi
(
idx
))
+
acc
(
idx
);
});
}
// load gamma (TODO: support no gamma?)
const
auto
gamma
=
load_tile
(
gamma_window
);
const
auto
gamma
=
load_tile
(
gamma_window
);
auto
y
=
make_static_distributed_tensor
<
YDataType
>
(
x
.
get_tile_distribution
());
// rmsnorm computation
auto
rmsn
=
make_static_distributed_tensor
<
ComputeDataType
>
(
x
.
get_tile_distribution
());
sweep_tile
(
y
,
[
&
,
inv_rms_
=
inv_rms
](
auto
idx
)
{
sweep_tile
(
rmsn
,
[
&
,
inv_rms_
=
inv_rms
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
const
auto
gamma_
=
type_convert
<
ComputeDataType
>
(
gamma
[
j_idx
]);
const
auto
gamma_
=
type_convert
<
ComputeDataType
>
(
gamma
[
j_idx
]);
const
auto
x_
=
type_convert
<
ComputeDataType
>
(
x
[
idx
]);
auto
rmsn_
=
acc
(
idx
)
*
inv_rms_
[
i_idx
]
*
gamma_
;
auto
y_
=
x_
*
inv_rms_
[
i_idx
]
*
gamma_
;
y
(
idx
)
=
type_convert
<
YDataType
>
(
y_
)
;
rmsn
(
idx
)
=
rmsn_
;
});
});
store_tile
(
y_window
,
y
);
static_assert
(
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
NO_SWEEP
);
Epilogue
{}(
y_window
,
rmsn
);
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_residual_window
,
{
0
,
-
Block_N
});
move_tile_window
(
gamma_window
,
{
-
Block_N
});
move_tile_window
(
gamma_window
,
{
-
Block_N
});
move_tile_window
(
y_window
,
{
0
,
-
Block_N
});
move_tile_window
(
y_window
,
{
0
,
-
Block_N
});
}
}
...
...
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp
0 → 100644
View file @
a7ae4f8e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace
ck_tile
{
enum
class
Rmsnorm2dFusedAddEnum
{
NO_ADD
=
0
,
// fused add before RMSNorm and store result to global
PRE_ADD_STORE
=
1
,
// fused add before RMSNorm, but not store result
PRE_ADD
=
2
,
};
// clang-format off
template
<
Rmsnorm2dFusedAddEnum
>
struct
Rmsnorm2dFusedAddEnumName
;
template
<
>
struct
Rmsnorm2dFusedAddEnumName
<
Rmsnorm2dFusedAddEnum
::
NO_ADD
>
{
static
constexpr
const
char
*
name
=
"no"
;
};
template
<
>
struct
Rmsnorm2dFusedAddEnumName
<
Rmsnorm2dFusedAddEnum
::
PRE_ADD_STORE
>
{
static
constexpr
const
char
*
name
=
"pras"
;
};
template
<
>
struct
Rmsnorm2dFusedAddEnumName
<
Rmsnorm2dFusedAddEnum
::
PRE_ADD
>
{
static
constexpr
const
char
*
name
=
"pra"
;
};
// clang-format on
enum
class
Rmsnorm2dFusedQuantEnum
{
NO_SWEEP
=
0
,
SMOOTH_DYNAMIC_QUANT
=
1
,
// smooth oulier + rowwise quant, need input x-scale and store y_scale
DYNAMIC_QUANT
=
2
,
// rowwise quant, store out a y-scale
};
// clang-format off
template
<
Rmsnorm2dFusedQuantEnum
>
struct
Rmsnorm2dFusedQuantEnumName
;
template
<
>
struct
Rmsnorm2dFusedQuantEnumName
<
Rmsnorm2dFusedQuantEnum
::
NO_SWEEP
>
{
static
constexpr
const
char
*
name
=
"no"
;
};
template
<
>
struct
Rmsnorm2dFusedQuantEnumName
<
Rmsnorm2dFusedQuantEnum
::
DYNAMIC_QUANT
>
{
static
constexpr
const
char
*
name
=
"dqt"
;
};
template
<
>
struct
Rmsnorm2dFusedQuantEnumName
<
Rmsnorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
>
{
static
constexpr
const
char
*
name
=
"smdqt"
;
};
// clang-format on
template
<
bool
kPadN_
,
bool
kSaveInvRms_
,
bool
kTwoPass_
,
Rmsnorm2dFusedAddEnum
kFusedAdd_
,
Rmsnorm2dFusedQuantEnum
kFusedQuant_
>
struct
Rmsnorm2dFwdTraits
{
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kSaveInvRms
=
kSaveInvRms_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
static
constexpr
Rmsnorm2dFusedAddEnum
kFusedAdd
=
kFusedAdd_
;
static
constexpr
Rmsnorm2dFusedQuantEnum
kFusedQuant
=
kFusedQuant_
;
};
}
// namespace ck_tile
include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp
View file @
a7ae4f8e
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -12,7 +12,7 @@ namespace ck_tile {
...
@@ -12,7 +12,7 @@ namespace ck_tile {
struct
MoeSmoothquantHostArgs
struct
MoeSmoothquantHostArgs
{
{
const
void
*
p_x
;
// [tokens ,hidden_size], input, fp16/bf16
const
void
*
p_x
;
// [tokens ,hidden_size], input, fp16/bf16
const
void
*
p_
x
scale
;
// [experts, hidden_size], input, columnwise scale, fp32
const
void
*
p_
sm
scale
;
// [experts, hidden_size], input, columnwise scale, fp32
const
void
*
p_topk_ids
;
// [tokens, topk]
const
void
*
p_topk_ids
;
// [tokens, topk]
void
*
p_yscale
;
// [topk * tokens, 1], output, rowwise quant scale
void
*
p_yscale
;
// [topk * tokens, 1], output, rowwise quant scale
...
@@ -33,11 +33,11 @@ struct MoeSmoothquant
...
@@ -33,11 +33,11 @@ struct MoeSmoothquant
using
Pipeline
=
remove_cvref_t
<
Pipeline_
>
;
using
Pipeline
=
remove_cvref_t
<
Pipeline_
>
;
using
Problem
=
typename
Pipeline
::
Problem
;
using
Problem
=
typename
Pipeline
::
Problem
;
using
XDataType
=
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
XDataType
=
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
X
ScaleDataType
=
remove_cvref_t
<
typename
Problem
::
X
ScaleDataType
>
;
using
Smooth
ScaleDataType
=
remove_cvref_t
<
typename
Problem
::
Smooth
ScaleDataType
>
;
using
ComputeDataType
=
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
ComputeDataType
=
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
YScaleDataType
=
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
using
YScaleDataType
=
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
using
QYDataType
=
remove_cvref_t
<
typename
Problem
::
QYDataType
>
;
using
QYDataType
=
remove_cvref_t
<
typename
Problem
::
QYDataType
>
;
static
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M
;
static
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M
;
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
...
@@ -57,7 +57,7 @@ struct MoeSmoothquant
...
@@ -57,7 +57,7 @@ struct MoeSmoothquant
struct
Kargs
struct
Kargs
{
{
const
void
*
p_x
;
// [tokens ,hidden_size], input, fp16/bf16
const
void
*
p_x
;
// [tokens ,hidden_size], input, fp16/bf16
const
void
*
p_
x
scale
;
// [experts, hidden_size], input, columnwise scale, fp32
const
void
*
p_
sm
scale
;
// [experts, hidden_size], input, columnwise scale, fp32
const
void
*
p_topk_ids
;
// [tokens, topk]
const
void
*
p_topk_ids
;
// [tokens, topk]
void
*
p_yscale
;
// [topk, tokens, 1], output, rowwise quant scale
void
*
p_yscale
;
// [topk, tokens, 1], output, rowwise quant scale
...
@@ -75,7 +75,7 @@ struct MoeSmoothquant
...
@@ -75,7 +75,7 @@ struct MoeSmoothquant
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
{
{
return
Kargs
{
hargs
.
p_x
,
return
Kargs
{
hargs
.
p_x
,
hargs
.
p_
x
scale
,
hargs
.
p_
sm
scale
,
hargs
.
p_topk_ids
,
hargs
.
p_topk_ids
,
hargs
.
p_yscale
,
hargs
.
p_yscale
,
hargs
.
p_qy
,
hargs
.
p_qy
,
...
@@ -101,6 +101,7 @@ struct MoeSmoothquant
...
@@ -101,6 +101,7 @@ struct MoeSmoothquant
template
<
>
struct
t2s
<
ck_tile
::
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
template
<
>
struct
t2s
<
ck_tile
::
fp8_t
>
{
static
constexpr
const
char
*
name
=
"fp8"
;
};
template
<
>
struct
t2s
<
ck_tile
::
fp8_t
>
{
static
constexpr
const
char
*
name
=
"fp8"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
template
<
>
struct
t2s
<
ck_tile
::
int8_t
>
{
static
constexpr
const
char
*
name
=
"i8"
;
};
// clang-format on
// clang-format on
// in byte
// in byte
...
@@ -118,7 +119,7 @@ struct MoeSmoothquant
...
@@ -118,7 +119,7 @@ struct MoeSmoothquant
#define _SS_ std::string
#define _SS_ std::string
#define _TS_ std::to_string
#define _TS_ std::to_string
return
_SS_
(
"moe_smoothquant_"
)
+
_SS_
(
t2s
<
XDataType
>::
name
)
+
"_"
+
return
_SS_
(
"moe_smoothquant_"
)
+
_SS_
(
t2s
<
XDataType
>::
name
)
+
"_"
+
_SS_
(
t2s
<
QYDataType
>::
name
)
+
"_"
+
_TS_
(
S_
::
Block_M
)
+
"x"
+
_TS_
(
S_
::
Block_N
)
+
"_"
+
_TS_
(
S_
::
WarpPerBlock_M
)
+
"x"
+
_TS_
(
S_
::
WarpPerBlock_N
)
+
"_"
+
_TS_
(
S_
::
Block_M
)
+
"x"
+
_TS_
(
S_
::
Block_N
)
+
"_"
+
_TS_
(
S_
::
WarpPerBlock_M
)
+
"x"
+
_TS_
(
S_
::
WarpPerBlock_N
)
+
"_"
+
_TS_
(
S_
::
Warp_M
)
+
"x"
+
_TS_
(
S_
::
Warp_N
)
+
"_"
+
_TS_
(
S_
::
Vector_M
)
+
"x"
+
_TS_
(
S_
::
Vector_N
)
+
"_"
+
_TS_
(
S_
::
Warp_M
)
+
"x"
+
_TS_
(
S_
::
Warp_N
)
+
"_"
+
_TS_
(
S_
::
Vector_M
)
+
"x"
+
_TS_
(
S_
::
Vector_N
)
+
"_"
+
_SS_
(
Pipeline
::
name
)
+
surfix
;
_SS_
(
Pipeline
::
name
)
+
surfix
;
...
@@ -153,9 +154,10 @@ struct MoeSmoothquant
...
@@ -153,9 +154,10 @@ struct MoeSmoothquant
}();
}();
// [experts, hidden_size],
// [experts, hidden_size],
const
auto
x
scale_window
=
[
&
]()
{
const
auto
sm
scale_window
=
[
&
]()
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
XScaleDataType
*>
(
kargs
.
p_xscale
)
+
i_expert
*
kargs
.
hidden_size
,
static_cast
<
const
SmoothScaleDataType
*>
(
kargs
.
p_smscale
)
+
i_expert
*
kargs
.
hidden_size
,
make_tuple
(
kargs
.
hidden_size
),
make_tuple
(
kargs
.
hidden_size
),
make_tuple
(
1
),
make_tuple
(
1
),
number
<
Vector_N
>
{},
number
<
Vector_N
>
{},
...
@@ -198,7 +200,7 @@ struct MoeSmoothquant
...
@@ -198,7 +200,7 @@ struct MoeSmoothquant
__shared__
char
smem
[
GetSmemSize
()];
__shared__
char
smem
[
GetSmemSize
()];
Pipeline
{}(
x_window
,
x
scale_window
,
yscale_window
,
qy_window
,
kargs
.
hidden_size
,
smem
);
Pipeline
{}(
x_window
,
sm
scale_window
,
yscale_window
,
qy_window
,
kargs
.
hidden_size
,
smem
);
}
}
};
};
...
...
include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp
View file @
a7ae4f8e
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -11,11 +11,11 @@ namespace ck_tile {
...
@@ -11,11 +11,11 @@ namespace ck_tile {
// host side args
// host side args
struct
SmoothquantHostArgs
struct
SmoothquantHostArgs
{
{
const
void
*
p_x
;
// [m ,n], input, fp16/bf16
const
void
*
p_x
;
// [m ,n], input, fp16/bf16
const
void
*
p_
x
scale
;
// [1, n], input, columnwise scale, fp32
const
void
*
p_
sm
scale
;
// [1, n], input, columnwise scale, fp32
void
*
p_yscale
;
// [m, 1], output, rowwise quant scale (amax / 127) of (p_x * p_
x
scale)
void
*
p_yscale
;
// [m, 1], output, rowwise quant scale (amax / 127) of (p_x * p_
sm
scale)
void
*
p_qy
;
// [m, n], output, p_x * p_
x
scale / p_yscale
void
*
p_qy
;
// [m, n], output, p_x * p_
sm
scale / p_yscale
index_t
m
;
index_t
m
;
index_t
n
;
index_t
n
;
...
@@ -30,11 +30,11 @@ struct Smoothquant
...
@@ -30,11 +30,11 @@ struct Smoothquant
using
Pipeline
=
remove_cvref_t
<
Pipeline_
>
;
using
Pipeline
=
remove_cvref_t
<
Pipeline_
>
;
using
Problem
=
typename
Pipeline
::
Problem
;
using
Problem
=
typename
Pipeline
::
Problem
;
using
XDataType
=
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
XDataType
=
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
X
ScaleDataType
=
remove_cvref_t
<
typename
Problem
::
X
ScaleDataType
>
;
using
Smooth
ScaleDataType
=
remove_cvref_t
<
typename
Problem
::
Smooth
ScaleDataType
>
;
using
ComputeDataType
=
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
ComputeDataType
=
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
YScaleDataType
=
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
using
YScaleDataType
=
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
using
QYDataType
=
remove_cvref_t
<
typename
Problem
::
QYDataType
>
;
using
QYDataType
=
remove_cvref_t
<
typename
Problem
::
QYDataType
>
;
static
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M
;
static
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M
;
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
...
@@ -52,7 +52,7 @@ struct Smoothquant
...
@@ -52,7 +52,7 @@ struct Smoothquant
struct
Kargs
struct
Kargs
{
{
const
void
*
p_x
;
const
void
*
p_x
;
const
void
*
p_
x
scale
;
const
void
*
p_
sm
scale
;
void
*
p_yscale
;
void
*
p_yscale
;
void
*
p_qy
;
void
*
p_qy
;
...
@@ -67,7 +67,7 @@ struct Smoothquant
...
@@ -67,7 +67,7 @@ struct Smoothquant
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
{
{
return
Kargs
{
hargs
.
p_x
,
return
Kargs
{
hargs
.
p_x
,
hargs
.
p_
x
scale
,
hargs
.
p_
sm
scale
,
hargs
.
p_yscale
,
hargs
.
p_yscale
,
hargs
.
p_qy
,
hargs
.
p_qy
,
hargs
.
m
,
hargs
.
m
,
...
@@ -134,9 +134,9 @@ struct Smoothquant
...
@@ -134,9 +134,9 @@ struct Smoothquant
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
}();
}();
const
auto
x
scale_window
=
[
&
]()
{
const
auto
sm
scale_window
=
[
&
]()
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
X
ScaleDataType
*>
(
kargs
.
p_
x
scale
),
static_cast
<
const
Smooth
ScaleDataType
*>
(
kargs
.
p_
sm
scale
),
make_tuple
(
kargs
.
n
),
make_tuple
(
kargs
.
n
),
make_tuple
(
1
),
make_tuple
(
1
),
number
<
Vector_N
>
{},
number
<
Vector_N
>
{},
...
@@ -177,7 +177,7 @@ struct Smoothquant
...
@@ -177,7 +177,7 @@ struct Smoothquant
__shared__
char
smem
[
GetSmemSize
()];
__shared__
char
smem
[
GetSmemSize
()];
Pipeline
{}(
x_window
,
x
scale_window
,
yscale_window
,
qy_window
,
kargs
.
n
,
smem
);
Pipeline
{}(
x_window
,
sm
scale_window
,
yscale_window
,
qy_window
,
kargs
.
n
,
smem
);
}
}
};
};
...
...
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_default_policy.hpp
View file @
a7ae4f8e
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -28,7 +28,7 @@ struct SmoothquantPipelineDefaultPolicy
...
@@ -28,7 +28,7 @@ struct SmoothquantPipelineDefaultPolicy
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
Make
X
ScaleBlockTileDistribution
()
CK_TILE_DEVICE
static
constexpr
auto
Make
Smooth
ScaleBlockTileDistribution
()
{
{
using
S
=
typename
Problem
::
BlockShape
;
using
S
=
typename
Problem
::
BlockShape
;
...
...
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp
View file @
a7ae4f8e
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -16,11 +16,11 @@ struct SmoothquantPipelineOnePass
...
@@ -16,11 +16,11 @@ struct SmoothquantPipelineOnePass
using
Problem
=
ck_tile
::
remove_cvref_t
<
Problem_
>
;
using
Problem
=
ck_tile
::
remove_cvref_t
<
Problem_
>
;
using
Policy
=
ck_tile
::
remove_cvref_t
<
Policy_
>
;
using
Policy
=
ck_tile
::
remove_cvref_t
<
Policy_
>
;
using
XDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
XDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
X
ScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
X
ScaleDataType
>
;
using
Smooth
ScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
Smooth
ScaleDataType
>
;
using
ComputeDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
ComputeDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
QYDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
QYDataType
>
;
using
QYDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
QYDataType
>
;
using
YScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
using
YScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockSmoothquantProblem::kPadM
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockSmoothquantProblem::kPadM
...
@@ -39,9 +39,12 @@ struct SmoothquantPipelineOnePass
...
@@ -39,9 +39,12 @@ struct SmoothquantPipelineOnePass
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
}
template
<
typename
XWindow
,
typename
XScaleWindow
,
typename
QYWindow
,
typename
YScaleWindow
>
template
<
typename
XWindow
,
typename
SmoothScaleWindow
,
typename
QYWindow
,
typename
YScaleWindow
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
const
X
ScaleWindow
&
x
scale_window_
,
const
Smooth
ScaleWindow
&
sm
scale_window_
,
YScaleWindow
&
yscale_window
,
YScaleWindow
&
yscale_window
,
QYWindow
&
qy_window
,
QYWindow
&
qy_window
,
ck_tile
::
index_t
,
ck_tile
::
index_t
,
...
@@ -49,8 +52,8 @@ struct SmoothquantPipelineOnePass
...
@@ -49,8 +52,8 @@ struct SmoothquantPipelineOnePass
{
{
auto
x_window
=
auto
x_window
=
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
x
scale_window
=
make_tile_window
(
auto
sm
scale_window
=
make_tile_window
(
x
scale_window_
,
Policy
::
template
Make
X
ScaleBlockTileDistribution
<
Problem
>());
sm
scale_window_
,
Policy
::
template
Make
Smooth
ScaleBlockTileDistribution
<
Problem
>());
auto
reduce_absmax_func
=
ReduceOp
::
AbsMax
{};
auto
reduce_absmax_func
=
ReduceOp
::
AbsMax
{};
auto
reduce_absmax3_func
=
[](
auto
acc_
,
auto
v_0_
,
auto
v_1_
)
{
auto
reduce_absmax3_func
=
[](
auto
acc_
,
auto
v_0_
,
auto
v_1_
)
{
...
@@ -67,14 +70,14 @@ struct SmoothquantPipelineOnePass
...
@@ -67,14 +70,14 @@ struct SmoothquantPipelineOnePass
auto
block_reduce2d_cross_warp_sync
=
auto
block_reduce2d_cross_warp_sync
=
Policy
::
template
GetBlockReduce2dCrossWarpSync
<
Problem
>();
Policy
::
template
GetBlockReduce2dCrossWarpSync
<
Problem
>();
const
auto
x
=
load_tile
(
x_window
);
const
auto
x
=
load_tile
(
x_window
);
const
auto
x
scale
=
load_tile
(
x
scale_window
);
const
auto
sm
scale
=
load_tile
(
sm
scale_window
);
auto
y
=
tile_elementwise_in
(
auto
y
=
tile_elementwise_in
(
[
&
](
const
auto
&
a
,
const
auto
&
b
)
{
[
&
](
const
auto
&
a
,
const
auto
&
b
)
{
return
type_convert
<
ComputeDataType
>
(
a
)
*
type_convert
<
ComputeDataType
>
(
b
);
return
type_convert
<
ComputeDataType
>
(
a
)
*
type_convert
<
ComputeDataType
>
(
b
);
},
},
x
,
x
,
x
scale
);
sm
scale
);
// compute absmax, cross-lane->cross-warp
// compute absmax, cross-lane->cross-warp
auto
absmax
=
[
&
]()
{
auto
absmax
=
[
&
]()
{
...
@@ -110,7 +113,7 @@ struct SmoothquantPipelineOnePass
...
@@ -110,7 +113,7 @@ struct SmoothquantPipelineOnePass
sweep_tile
(
qy
,
[
&
](
auto
idx
)
{
sweep_tile
(
qy
,
[
&
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
auto
qy_
=
y
[
idx
]
/
yscale
[
i_idx
];
auto
qy_
=
y
[
idx
]
/
yscale
[
i_idx
];
qy
(
idx
)
=
saturates
<
QYDataType
>
{}(
qy_
);
qy
(
idx
)
=
type_convert
<
QYDataType
>
(
saturates
<
QYDataType
>
{}(
qy_
)
)
;
});
});
store_tile
(
qy_window
,
qy
);
store_tile
(
qy_window
,
qy
);
}
}
...
...
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp
View file @
a7ae4f8e
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -7,9 +7,9 @@
...
@@ -7,9 +7,9 @@
namespace
ck_tile
{
namespace
ck_tile
{
// Y = X *
X
Scale, QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale)
// Y = X *
Smooth
Scale, QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale)
template
<
typename
XDataType_
,
template
<
typename
XDataType_
,
typename
X
ScaleDataType_
,
typename
Smooth
ScaleDataType_
,
typename
ComputeDataType_
,
typename
ComputeDataType_
,
typename
YScaleDataType_
,
typename
YScaleDataType_
,
typename
QYDataType_
,
typename
QYDataType_
,
...
@@ -18,12 +18,12 @@ template <typename XDataType_,
...
@@ -18,12 +18,12 @@ template <typename XDataType_,
bool
kTwoPass_
>
bool
kTwoPass_
>
struct
SmoothquantPipelineProblem
struct
SmoothquantPipelineProblem
{
{
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
X
ScaleDataType
=
remove_cvref_t
<
X
ScaleDataType_
>
;
using
Smooth
ScaleDataType
=
remove_cvref_t
<
Smooth
ScaleDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
YScaleDataType
=
remove_cvref_t
<
YScaleDataType_
>
;
using
YScaleDataType
=
remove_cvref_t
<
YScaleDataType_
>
;
using
QYDataType
=
remove_cvref_t
<
QYDataType_
>
;
using
QYDataType
=
remove_cvref_t
<
QYDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
bool
kNeedCrossLaneSync
=
BlockShape
::
ThreadPerWarp_N
>
1
;
static
constexpr
bool
kNeedCrossLaneSync
=
BlockShape
::
ThreadPerWarp_N
>
1
;
static
constexpr
bool
kNeedCrossWarpSync
=
BlockShape
::
WarpPerBlock_N
>
1
;
static
constexpr
bool
kNeedCrossWarpSync
=
BlockShape
::
WarpPerBlock_N
>
1
;
...
...
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp
View file @
a7ae4f8e
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -16,11 +16,11 @@ struct SmoothquantPipelineTwoPass
...
@@ -16,11 +16,11 @@ struct SmoothquantPipelineTwoPass
using
Problem
=
ck_tile
::
remove_cvref_t
<
Problem_
>
;
using
Problem
=
ck_tile
::
remove_cvref_t
<
Problem_
>
;
using
Policy
=
ck_tile
::
remove_cvref_t
<
Policy_
>
;
using
Policy
=
ck_tile
::
remove_cvref_t
<
Policy_
>
;
using
XDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
XDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
X
ScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
X
ScaleDataType
>
;
using
Smooth
ScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
Smooth
ScaleDataType
>
;
using
ComputeDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
ComputeDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
QYDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
QYDataType
>
;
using
QYDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
QYDataType
>
;
using
YScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
using
YScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockSmoothquantProblem::kPadM
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockSmoothquantProblem::kPadM
...
@@ -39,9 +39,12 @@ struct SmoothquantPipelineTwoPass
...
@@ -39,9 +39,12 @@ struct SmoothquantPipelineTwoPass
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
}
template
<
typename
XWindow
,
typename
XScaleWindow
,
typename
QYWindow
,
typename
YScaleWindow
>
template
<
typename
XWindow
,
typename
SmoothScaleWindow
,
typename
QYWindow
,
typename
YScaleWindow
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
const
X
ScaleWindow
&
x
scale_window_
,
const
Smooth
ScaleWindow
&
sm
scale_window_
,
YScaleWindow
&
yscale_window
,
YScaleWindow
&
yscale_window
,
QYWindow
&
qy_window
,
QYWindow
&
qy_window
,
ck_tile
::
index_t
row_size
,
ck_tile
::
index_t
row_size
,
...
@@ -49,8 +52,8 @@ struct SmoothquantPipelineTwoPass
...
@@ -49,8 +52,8 @@ struct SmoothquantPipelineTwoPass
{
{
auto
x_window
=
auto
x_window
=
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
x
scale_window
=
make_tile_window
(
auto
sm
scale_window
=
make_tile_window
(
x
scale_window_
,
Policy
::
template
Make
X
ScaleBlockTileDistribution
<
Problem
>());
sm
scale_window_
,
Policy
::
template
Make
Smooth
ScaleBlockTileDistribution
<
Problem
>());
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
index_t
num_n_tile_iteration
=
index_t
num_n_tile_iteration
=
...
@@ -76,14 +79,14 @@ struct SmoothquantPipelineTwoPass
...
@@ -76,14 +79,14 @@ struct SmoothquantPipelineTwoPass
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
{
const
auto
x
=
load_tile
(
x_window
);
const
auto
x
=
load_tile
(
x_window
);
const
auto
x
scale
=
load_tile
(
x
scale_window
);
const
auto
sm
scale
=
load_tile
(
sm
scale_window
);
const
auto
y
=
tile_elementwise_in
(
const
auto
y
=
tile_elementwise_in
(
[
&
](
const
auto
&
a
,
const
auto
&
b
)
{
[
&
](
const
auto
&
a
,
const
auto
&
b
)
{
return
type_convert
<
ComputeDataType
>
(
a
)
*
type_convert
<
ComputeDataType
>
(
b
);
return
type_convert
<
ComputeDataType
>
(
a
)
*
type_convert
<
ComputeDataType
>
(
b
);
},
},
x
,
x
,
x
scale
);
sm
scale
);
constexpr
auto
x_size_per_row
=
constexpr
auto
x_size_per_row
=
x
.
get_tile_distribution
().
get_ys_to_d_descriptor
().
get_lengths
().
at
(
number
<
1
>
{});
x
.
get_tile_distribution
().
get_ys_to_d_descriptor
().
get_lengths
().
at
(
number
<
1
>
{});
...
@@ -94,7 +97,7 @@ struct SmoothquantPipelineTwoPass
...
@@ -94,7 +97,7 @@ struct SmoothquantPipelineTwoPass
block_reduce2d
(
y
,
absmax
,
reduce_absmax_func
);
block_reduce2d
(
y
,
absmax
,
reduce_absmax_func
);
move_tile_window
(
x_window
,
{
0
,
Block_N
});
move_tile_window
(
x_window
,
{
0
,
Block_N
});
move_tile_window
(
x
scale_window
,
{
Block_N
});
move_tile_window
(
sm
scale_window
,
{
Block_N
});
}
}
// compute absmax, cross-lane->cross-warp
// compute absmax, cross-lane->cross-warp
...
@@ -114,31 +117,31 @@ struct SmoothquantPipelineTwoPass
...
@@ -114,31 +117,31 @@ struct SmoothquantPipelineTwoPass
row_size
%
Block_N
==
0
?
row_size
-
Block_N
:
row_size
-
row_size
%
Block_N
;
row_size
%
Block_N
==
0
?
row_size
-
Block_N
:
row_size
-
row_size
%
Block_N
;
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x
scale_window
,
{
-
Block_N
});
move_tile_window
(
sm
scale_window
,
{
-
Block_N
});
move_tile_window
(
qy_window
,
{
0
,
stride_to_right_most_window
});
move_tile_window
(
qy_window
,
{
0
,
stride_to_right_most_window
});
// recompute y and quantize y to qy
// recompute y and quantize y to qy
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
{
const
auto
x
=
load_tile
(
x_window
);
const
auto
x
=
load_tile
(
x_window
);
const
auto
x
scale
=
load_tile
(
x
scale_window
);
const
auto
sm
scale
=
load_tile
(
sm
scale_window
);
const
auto
y
=
tile_elementwise_in
(
const
auto
y
=
tile_elementwise_in
(
[
&
](
const
auto
&
a
,
const
auto
&
b
)
{
[
&
](
const
auto
&
a
,
const
auto
&
b
)
{
return
type_convert
<
ComputeDataType
>
(
a
)
*
type_convert
<
ComputeDataType
>
(
b
);
return
type_convert
<
ComputeDataType
>
(
a
)
*
type_convert
<
ComputeDataType
>
(
b
);
},
},
x
,
x
,
x
scale
);
sm
scale
);
auto
qy
=
make_static_distributed_tensor
<
QYDataType
>
(
y
.
get_tile_distribution
());
auto
qy
=
make_static_distributed_tensor
<
QYDataType
>
(
y
.
get_tile_distribution
());
sweep_tile
(
qy
,
[
&
](
auto
idx
)
{
sweep_tile
(
qy
,
[
&
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
auto
qy_
=
y
[
idx
]
/
yscale
[
i_idx
];
auto
qy_
=
y
[
idx
]
/
yscale
[
i_idx
];
qy
(
idx
)
=
saturates
<
QYDataType
>
{}(
qy_
);
qy
(
idx
)
=
type_convert
<
QYDataType
>
(
saturates
<
QYDataType
>
{}(
qy_
)
)
;
});
});
store_tile
(
qy_window
,
qy
);
store_tile
(
qy_window
,
qy
);
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x
scale_window
,
{
0
,
-
Block_N
});
move_tile_window
(
sm
scale_window
,
{
0
,
-
Block_N
});
move_tile_window
(
qy_window
,
{
0
,
-
Block_N
});
move_tile_window
(
qy_window
,
{
0
,
-
Block_N
});
}
}
}
}
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp
View file @
a7ae4f8e
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -73,39 +73,9 @@ struct ReferencefpAintBGemm : public device::BaseOperator
...
@@ -73,39 +73,9 @@ struct ReferencefpAintBGemm : public device::BaseOperator
ScaleDataType
v_scale
;
ScaleDataType
v_scale
;
ADataType
v_converted_b
;
ADataType
v_converted_b
;
// use PassThrough instead of ConvertBF16RTN for reference calculation
arg
.
a_element_op_
(
v_a
,
arg
.
a_m_k_
(
m
,
k
));
if
constexpr
(
is_same_v
<
AElementwiseOperation
,
arg
.
b_element_op_
(
v_b
,
arg
.
b_k_n_
(
k
,
n
));
ck
::
tensor_operation
::
element_wise
::
ConvertBF16RTN
>
)
arg
.
b_element_op_
(
v_scale
,
arg
.
scale_k_n_
(
k
,
n
));
{
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}(
v_a
,
arg
.
a_m_k_
(
m
,
k
));
}
else
{
arg
.
a_element_op_
(
v_a
,
arg
.
a_m_k_
(
m
,
k
));
}
// same for B matrix
if
constexpr
(
is_same_v
<
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
ConvertBF16RTN
>
)
{
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}(
v_b
,
arg
.
b_k_n_
(
k
,
n
));
}
else
{
arg
.
b_element_op_
(
v_b
,
arg
.
b_k_n_
(
k
,
n
));
}
// same for scale matrix
if
constexpr
(
is_same_v
<
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
ConvertBF16RTN
>
)
{
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}(
v_scale
,
arg
.
scale_k_n_
(
k
,
n
));
}
else
{
arg
.
b_element_op_
(
v_scale
,
arg
.
scale_k_n_
(
k
,
n
));
}
v_converted_b
=
type_convert
<
ADataType
>
(
v_b
)
*
v_scale
;
v_converted_b
=
type_convert
<
ADataType
>
(
v_b
)
*
v_scale
;
v_acc
+=
ck
::
type_convert
<
AccDataType
>
(
v_a
)
*
v_acc
+=
ck
::
type_convert
<
AccDataType
>
(
v_a
)
*
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
View file @
a7ae4f8e
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -68,13 +68,7 @@ struct ReferenceGemm : public device::BaseOperator
...
@@ -68,13 +68,7 @@ struct ReferenceGemm : public device::BaseOperator
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
// use PassThrough instead of ConvertBF16RTN for reference calculation
if
constexpr
(
is_same_v
<
ADataType
,
pk_i4_t
>
)
if
constexpr
(
is_same_v
<
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
ConvertBF16RTN
>
)
{
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}(
v_a
,
arg
.
a_m_k_
(
m
,
k
));
}
else
if
constexpr
(
is_same_v
<
ADataType
,
pk_i4_t
>
)
{
{
uint8_t
i4x2
=
arg
.
a_m_k_
(
m
,
k
).
data
;
uint8_t
i4x2
=
arg
.
a_m_k_
(
m
,
k
).
data
;
int8_t
i4
=
0
;
int8_t
i4
=
0
;
...
@@ -89,13 +83,8 @@ struct ReferenceGemm : public device::BaseOperator
...
@@ -89,13 +83,8 @@ struct ReferenceGemm : public device::BaseOperator
{
{
arg
.
a_element_op_
(
v_a
,
arg
.
a_m_k_
(
m
,
k
));
arg
.
a_element_op_
(
v_a
,
arg
.
a_m_k_
(
m
,
k
));
}
}
// same for B matrix
if
constexpr
(
is_same_v
<
BElementwiseOperation
,
if
constexpr
(
is_same_v
<
BDataType
,
pk_i4_t
>
)
ck
::
tensor_operation
::
element_wise
::
ConvertBF16RTN
>
)
{
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}(
v_b
,
arg
.
b_k_n_
(
k
,
n
));
}
else
if
constexpr
(
is_same_v
<
BDataType
,
pk_i4_t
>
)
{
{
uint8_t
i4x2
=
arg
.
b_k_n_
(
k
,
n
).
data
;
uint8_t
i4x2
=
arg
.
b_k_n_
(
k
,
n
).
data
;
int8_t
i4
=
0
;
int8_t
i4
=
0
;
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp
View file @
a7ae4f8e
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -74,26 +74,8 @@ struct ReferenceGemmMultipleD : public device::BaseOperator
...
@@ -74,26 +74,8 @@ struct ReferenceGemmMultipleD : public device::BaseOperator
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
// use PassThrough instead of ConvertBF16RTN for reference calculation
arg
.
a_element_op_
(
v_a
,
arg
.
a_m_k_
(
m
,
k
));
if
constexpr
(
is_same_v
<
AElementwiseOperation
,
arg
.
b_element_op_
(
v_b
,
arg
.
b_k_n_
(
k
,
n
));
ck
::
tensor_operation
::
element_wise
::
ConvertBF16RTN
>
)
{
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}(
v_a
,
arg
.
a_m_k_
(
m
,
k
));
}
else
{
arg
.
a_element_op_
(
v_a
,
arg
.
a_m_k_
(
m
,
k
));
}
// same for B matrix
if
constexpr
(
is_same_v
<
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
ConvertBF16RTN
>
)
{
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}(
v_b
,
arg
.
b_k_n_
(
k
,
n
));
}
else
{
arg
.
b_element_op_
(
v_b
,
arg
.
b_k_n_
(
k
,
n
));
}
v_acc
+=
v_acc
+=
ck
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck
::
type_convert
<
AccDataType
>
(
v_b
);
ck
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck
::
type_convert
<
AccDataType
>
(
v_b
);
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp
View file @
a7ae4f8e
...
@@ -126,6 +126,35 @@ void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances(
...
@@ -126,6 +126,35 @@ void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances(
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
// bf16_inputA bf16_inputB
#if defined(CK_ENABLE_BF16)
void
add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemmFixedNK
<
Row
,
Row
,
Empty_Tuple
,
Row
,
BF16
,
BF16
,
Empty_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemmFixedNK
<
Row
,
Col
,
Empty_Tuple
,
Row
,
BF16
,
BF16
,
Empty_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif // CK_ENABLE_BF16
template
<
typename
ALayout
,
template
<
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
ELayout
,
typename
ELayout
,
...
@@ -227,6 +256,24 @@ struct DeviceOperationInstanceFactory<
...
@@ -227,6 +256,24 @@ struct DeviceOperationInstanceFactory<
}
}
#endif
#endif
// bf16_inputA bf16_inputB
#if defined(CK_ENABLE_BF16)
if
constexpr
(
is_same_v
<
ADataType
,
bhalf_t
>
&&
is_same_v
<
BDataType
,
bhalf_t
>
&&
is_same_v
<
EDataType
,
bhalf_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances
(
op_ptrs
);
}
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances
(
op_ptrs
);
}
}
#endif // CK_ENABLE_BF16
return
op_ptrs
;
return
op_ptrs
;
}
}
};
};
...
...
Prev
1
…
4
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment