Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
8b49f207
Unverified
Commit
8b49f207
authored
Jan 07, 2025
by
Max Podkorytov
Committed by
GitHub
Jan 07, 2025
Browse files
Merge branch 'develop' into fa-h512
parents
0d59f474
a6b761c3
Changes
262
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
466 additions
and
114 deletions
+466
-114
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
...e/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
+274
-0
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
+4
-0
include/ck_tile/ops/image_to_column.hpp
include/ck_tile/ops/image_to_column.hpp
+1
-1
include/ck_tile/ops/layernorm2d.hpp
include/ck_tile/ops/layernorm2d.hpp
+1
-1
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp
...rm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp
+30
-27
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
+27
-13
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
+11
-9
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp
..._tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp
+2
-0
include/ck_tile/ops/norm_reduce.hpp
include/ck_tile/ops/norm_reduce.hpp
+10
-0
include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp
include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp
+78
-48
include/ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp
..._tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp
+7
-2
include/ck_tile/ops/norm_reduce/thread/thread_welford.hpp
include/ck_tile/ops/norm_reduce/thread/thread_welford.hpp
+0
-0
include/ck_tile/ops/permute.hpp
include/ck_tile/ops/permute.hpp
+1
-1
include/ck_tile/ops/reduce.hpp
include/ck_tile/ops/reduce.hpp
+1
-1
include/ck_tile/ops/rmsnorm2d.hpp
include/ck_tile/ops/rmsnorm2d.hpp
+1
-1
include/ck_tile/ops/smoothquant.hpp
include/ck_tile/ops/smoothquant.hpp
+1
-1
include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp
...ude/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp
+14
-6
include/ck_tile/ops/softmax.hpp
include/ck_tile/ops/softmax.hpp
+1
-1
include/ck_tile/ops/topk.hpp
include/ck_tile/ops/topk.hpp
+1
-1
include/ck_tile/ops/topk_softmax.hpp
include/ck_tile/ops/topk_softmax.hpp
+1
-1
No files found.
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
View file @
8b49f207
...
@@ -78,6 +78,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
...
@@ -78,6 +78,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
static
constexpr
index_t
kN
=
32
;
static
constexpr
index_t
kN
=
32
;
static
constexpr
index_t
kK
=
8
;
static
constexpr
index_t
kK
=
8
;
static
constexpr
index_t
kAMBlock
=
1
;
static
constexpr
index_t
kBNBlock
=
1
;
static
constexpr
index_t
kAMLane
=
32
;
static
constexpr
index_t
kAMLane
=
32
;
static
constexpr
index_t
kBNLane
=
32
;
static
constexpr
index_t
kBNLane
=
32
;
static
constexpr
index_t
kABKLane
=
2
;
static
constexpr
index_t
kABKLane
=
2
;
...
@@ -138,6 +141,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
...
@@ -138,6 +141,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
static
constexpr
index_t
kN
=
16
;
static
constexpr
index_t
kN
=
16
;
static
constexpr
index_t
kK
=
16
;
static
constexpr
index_t
kK
=
16
;
static
constexpr
index_t
kAMBlock
=
1
;
static
constexpr
index_t
kBNBlock
=
1
;
static
constexpr
index_t
kAMLane
=
16
;
static
constexpr
index_t
kAMLane
=
16
;
static
constexpr
index_t
kBNLane
=
16
;
static
constexpr
index_t
kBNLane
=
16
;
static
constexpr
index_t
kABKLane
=
4
;
static
constexpr
index_t
kABKLane
=
4
;
...
@@ -182,6 +188,134 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
...
@@ -182,6 +188,134 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
}
}
};
};
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImplF16F16F32M4N64K4
{
static
constexpr
WGAttrCtlEnum
Ctrl
=
Ctrl_
;
using
ADataType
=
fp16_t
;
using
BDataType
=
fp16_t
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
CVecType
=
ext_vector_t
<
float
,
4
>
;
static
constexpr
index_t
kM
=
4
;
static
constexpr
index_t
kN
=
64
;
static
constexpr
index_t
kK
=
4
;
static
constexpr
index_t
kAMBlock
=
1
;
static
constexpr
index_t
kBNBlock
=
16
;
// we only write down single block (4 threads) thread mapping here
static
constexpr
index_t
kAMLane
=
4
;
static
constexpr
index_t
kBNLane
=
4
;
static
constexpr
index_t
kABKLane
=
1
;
static
constexpr
index_t
kABKPerLane
=
4
;
static
constexpr
index_t
kCMLane
=
1
;
static
constexpr
index_t
kCNLane
=
4
;
static
constexpr
index_t
kCM0PerLane
=
1
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
template
<
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
DISPATCH_MFMA_CTRL_
(
"v_mfma_f32_4x4x4f16"
,
Ctrl
)
else
{
#if defined(__gfx9__)
c_vec
=
__builtin_amdgcn_mfma_f32_4x4x4f16
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#else
ignore
=
c_vec
;
ignore
=
a_vec
;
ignore
=
b_vec
;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx9__)
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_4x4x4f16
(
a_vec
,
b_vec
,
fp32x4_t
{
0.
f
},
0
,
0
,
0
));
#else
ignore
=
a_vec
;
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
#endif
}
};
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImplF16F16F32M64N4K4
{
static
constexpr
WGAttrCtlEnum
Ctrl
=
Ctrl_
;
using
ADataType
=
fp16_t
;
using
BDataType
=
fp16_t
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
CVecType
=
ext_vector_t
<
float
,
4
>
;
static
constexpr
index_t
kM
=
64
;
static
constexpr
index_t
kN
=
4
;
static
constexpr
index_t
kK
=
4
;
static
constexpr
index_t
kAMBlock
=
16
;
static
constexpr
index_t
kBNBlock
=
1
;
// we only write down single block (4 threads) thread mapping here
static
constexpr
index_t
kAMLane
=
4
;
static
constexpr
index_t
kBNLane
=
4
;
static
constexpr
index_t
kABKLane
=
1
;
static
constexpr
index_t
kABKPerLane
=
4
;
static
constexpr
index_t
kCMLane
=
1
;
static
constexpr
index_t
kCNLane
=
4
;
static
constexpr
index_t
kCM0PerLane
=
1
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
template
<
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
DISPATCH_MFMA_CTRL_
(
"v_mfma_f32_4x4x4f16"
,
Ctrl
)
else
{
#if defined(__gfx9__)
c_vec
=
__builtin_amdgcn_mfma_f32_4x4x4f16
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#else
ignore
=
c_vec
;
ignore
=
a_vec
;
ignore
=
b_vec
;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx9__)
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_4x4x4f16
(
a_vec
,
b_vec
,
fp32x4_t
{
0.
f
},
0
,
0
,
0
));
#else
ignore
=
a_vec
;
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
#endif
}
};
// Bf16
// Bf16
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
struct
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
...
@@ -199,6 +333,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
...
@@ -199,6 +333,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
static
constexpr
index_t
kN
=
32
;
static
constexpr
index_t
kN
=
32
;
static
constexpr
index_t
kK
=
8
;
static
constexpr
index_t
kK
=
8
;
static
constexpr
index_t
kAMBlock
=
1
;
static
constexpr
index_t
kBNBlock
=
1
;
static
constexpr
index_t
kAMLane
=
32
;
static
constexpr
index_t
kAMLane
=
32
;
static
constexpr
index_t
kBNLane
=
32
;
static
constexpr
index_t
kBNLane
=
32
;
static
constexpr
index_t
kABKLane
=
2
;
static
constexpr
index_t
kABKLane
=
2
;
...
@@ -285,6 +422,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
...
@@ -285,6 +422,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
static
constexpr
index_t
kN
=
16
;
static
constexpr
index_t
kN
=
16
;
static
constexpr
index_t
kK
=
16
;
static
constexpr
index_t
kK
=
16
;
static
constexpr
index_t
kAMBlock
=
1
;
static
constexpr
index_t
kBNBlock
=
1
;
static
constexpr
index_t
kAMLane
=
16
;
static
constexpr
index_t
kAMLane
=
16
;
static
constexpr
index_t
kBNLane
=
16
;
static
constexpr
index_t
kBNLane
=
16
;
static
constexpr
index_t
kABKLane
=
4
;
static
constexpr
index_t
kABKLane
=
4
;
...
@@ -354,6 +494,134 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
...
@@ -354,6 +494,134 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
}
}
};
};
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4
{
static
constexpr
WGAttrCtlEnum
Ctrl
=
Ctrl_
;
using
ADataType
=
bf16_t
;
using
BDataType
=
bf16_t
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
CVecType
=
ext_vector_t
<
float
,
4
>
;
static
constexpr
index_t
kM
=
4
;
static
constexpr
index_t
kN
=
64
;
static
constexpr
index_t
kK
=
4
;
static
constexpr
index_t
kAMBlock
=
1
;
static
constexpr
index_t
kBNBlock
=
16
;
// we only write down single block (4 threads) thread mapping here
static
constexpr
index_t
kAMLane
=
4
;
static
constexpr
index_t
kBNLane
=
4
;
static
constexpr
index_t
kABKLane
=
1
;
static
constexpr
index_t
kABKPerLane
=
4
;
static
constexpr
index_t
kCMLane
=
1
;
static
constexpr
index_t
kCNLane
=
4
;
static
constexpr
index_t
kCM0PerLane
=
1
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
template
<
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
DISPATCH_MFMA_CTRL_
(
"v_mfma_f32_4x4x4bf16_1k"
,
Ctrl
)
else
{
#if defined(__gfx9__)
c_vec
=
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#else
ignore
=
c_vec
;
ignore
=
a_vec
;
ignore
=
b_vec
;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx9__)
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k
(
a_vec
,
b_vec
,
fp32x4_t
{
0.
f
},
0
,
0
,
0
));
#else
ignore
=
a_vec
;
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
#endif
}
};
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4
{
static
constexpr
WGAttrCtlEnum
Ctrl
=
Ctrl_
;
using
ADataType
=
bf16_t
;
using
BDataType
=
bf16_t
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
CVecType
=
ext_vector_t
<
float
,
4
>
;
static
constexpr
index_t
kM
=
64
;
static
constexpr
index_t
kN
=
4
;
static
constexpr
index_t
kK
=
4
;
static
constexpr
index_t
kAMBlock
=
16
;
static
constexpr
index_t
kBNBlock
=
1
;
// we only write down single block (4 threads) thread mapping here
static
constexpr
index_t
kAMLane
=
4
;
static
constexpr
index_t
kBNLane
=
4
;
static
constexpr
index_t
kABKLane
=
1
;
static
constexpr
index_t
kABKPerLane
=
4
;
static
constexpr
index_t
kCMLane
=
1
;
static
constexpr
index_t
kCNLane
=
4
;
static
constexpr
index_t
kCM0PerLane
=
1
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
template
<
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
DISPATCH_MFMA_CTRL_
(
"v_mfma_f32_4x4x4bf16_1k"
,
Ctrl
)
else
{
#if defined(__gfx9__)
c_vec
=
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#else
ignore
=
c_vec
;
ignore
=
a_vec
;
ignore
=
b_vec
;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx9__)
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k
(
a_vec
,
b_vec
,
fp32x4_t
{
0.
f
},
0
,
0
,
0
));
#else
ignore
=
a_vec
;
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
#endif
}
};
// FP8
// FP8
template
<
typename
AType_
,
typename
BType_
,
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
template
<
typename
AType_
,
typename
BType_
,
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
struct
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
...
@@ -371,6 +639,9 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
...
@@ -371,6 +639,9 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
static
constexpr
index_t
kN
=
32
;
static
constexpr
index_t
kN
=
32
;
static
constexpr
index_t
kK
=
16
;
static
constexpr
index_t
kK
=
16
;
static
constexpr
index_t
kAMBlock
=
1
;
static
constexpr
index_t
kBNBlock
=
1
;
static
constexpr
index_t
kAMLane
=
32
;
static
constexpr
index_t
kAMLane
=
32
;
static
constexpr
index_t
kBNLane
=
32
;
static
constexpr
index_t
kBNLane
=
32
;
static
constexpr
index_t
kABKLane
=
2
;
static
constexpr
index_t
kABKLane
=
2
;
...
@@ -568,6 +839,9 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
...
@@ -568,6 +839,9 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
static
constexpr
index_t
kN
=
32
;
static
constexpr
index_t
kN
=
32
;
static
constexpr
index_t
kK
=
16
;
static
constexpr
index_t
kK
=
16
;
static
constexpr
index_t
kAMBlock
=
1
;
static
constexpr
index_t
kBNBlock
=
1
;
static
constexpr
index_t
kAMLane
=
32
;
static
constexpr
index_t
kAMLane
=
32
;
static
constexpr
index_t
kBNLane
=
32
;
static
constexpr
index_t
kBNLane
=
32
;
static
constexpr
index_t
kABKLane
=
2
;
static
constexpr
index_t
kABKLane
=
2
;
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
View file @
8b49f207
...
@@ -29,6 +29,8 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
...
@@ -29,6 +29,8 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
4
,
64
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M4N64K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
64
,
4
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M64N4K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
;
};
...
@@ -42,6 +44,8 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
...
@@ -42,6 +44,8 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
4
,
64
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M4N64K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
64
,
4
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M64N4K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
;
};
...
...
include/ck_tile/ops/image_to_column.hpp
View file @
8b49f207
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck_tile/ops/layernorm2d.hpp
View file @
8b49f207
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp
View file @
8b49f207
...
@@ -4,8 +4,8 @@
...
@@ -4,8 +4,8 @@
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/
welford
/block/block_
welford
_problem.hpp"
#include "ck_tile/ops/
norm_reduce
/block/block_
norm_reduce
_problem.hpp"
#include "ck_tile/ops/
welford
/block/block_
welford
.hpp"
#include "ck_tile/ops/
norm_reduce
/block/block_
norm_reduce
.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -43,36 +43,38 @@ struct Layernorm2dFwdPipelineDefaultPolicy
...
@@ -43,36 +43,38 @@ struct Layernorm2dFwdPipelineDefaultPolicy
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlock
Welford
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlock
NormReduce
()
{
{
using
P_
=
Block
Welford
Problem
<
typename
Problem
::
ComputeDataType
,
using
P_
=
Block
NormReduce
Problem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
,
typename
Problem
::
BlockShape
,
Problem
::
Traits
::
kFastFDiv
>
;
Problem
::
Traits
::
kFastFDiv
,
Problem
::
Traits
::
kWelford
>
;
return
Block
Welford
<
P_
>
{};
return
Block
NormReduce
<
P_
>
{};
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlock
Welford
Sync
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlock
NormReduce
Sync
()
{
{
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
ComputeDataType
,
using
P_
=
BlockNormReduceProblem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
,
typename
Problem
::
BlockShape
,
Problem
::
Traits
::
kFastFDiv
>
;
Problem
::
Traits
::
kFastFDiv
,
Problem
::
Traits
::
kWelford
>
;
return
Block
Welford
Sync
<
P_
>
{};
return
Block
NormReduce
Sync
<
P_
>
{};
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlock
Welford
CrossWarpSync
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlock
NormReduce
CrossWarpSync
()
{
{
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
ComputeDataType
,
using
P_
=
BlockNormReduceProblem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
,
typename
Problem
::
BlockShape
,
Problem
::
Traits
::
kFastFDiv
>
;
Problem
::
Traits
::
kFastFDiv
,
Problem
::
Traits
::
kWelford
>
;
return
Block
Welford
CrossWarpSync
<
P_
>
{};
return
Block
NormReduce
CrossWarpSync
<
P_
>
{};
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
@@ -80,19 +82,20 @@ struct Layernorm2dFwdPipelineDefaultPolicy
...
@@ -80,19 +82,20 @@ struct Layernorm2dFwdPipelineDefaultPolicy
{
{
if
constexpr
(
Problem
::
kNeedCrossWarpSync
)
if
constexpr
(
Problem
::
kNeedCrossWarpSync
)
{
{
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
ComputeDataType
,
using
P_
=
BlockNormReduceProblem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
,
typename
Problem
::
BlockShape
,
Problem
::
Traits
::
kFastFDiv
>
;
Problem
::
Traits
::
kFastFDiv
,
Problem
::
Traits
::
kWelford
>
;
using
block_welford
=
Block
Welford
<
P_
>
;
using
block_welford
=
Block
NormReduce
<
P_
>
;
using
x_block_tile
=
using
x_block_tile
=
decltype
(
make_static_distributed_tensor
<
typename
Problem
::
ComputeDataType
>
(
decltype
(
make_static_distributed_tensor
<
typename
Problem
::
ComputeDataType
>
(
MakeXBlockTileDistribution
<
Problem
>
()));
MakeXBlockTileDistribution
<
Problem
>
()));
using
mean_var_block_tile
=
using
mean_var_block_tile
=
decltype
(
block_welford
::
template
MakeMeanVarBlockTile
<
x_block_tile
>());
decltype
(
block_welford
::
template
MakeMeanVarBlockTile
<
x_block_tile
>());
return
GetBlock
Welford
CrossWarpSync
<
Problem
>
()
return
GetBlock
NormReduce
CrossWarpSync
<
Problem
>
()
.
template
GetSmemSize
<
mean_var_block_tile
>();
.
template
GetSmemSize
<
mean_var_block_tile
>();
}
}
else
else
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
View file @
8b49f207
...
@@ -37,6 +37,7 @@ struct Layernorm2dFwdPipelineOnePass
...
@@ -37,6 +37,7 @@ struct Layernorm2dFwdPipelineOnePass
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockLayernorm2dFwdProblem::kPadM
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockLayernorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
bool
kFastFDiv
=
Problem
::
Traits
::
kFastFDiv
;
static
constexpr
bool
kFastFDiv
=
Problem
::
Traits
::
kFastFDiv
;
static
constexpr
bool
kWelford
=
Problem
::
Traits
::
kWelford
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
...
@@ -95,11 +96,16 @@ struct Layernorm2dFwdPipelineOnePass
...
@@ -95,11 +96,16 @@ struct Layernorm2dFwdPipelineOnePass
int
cur_count
=
0
;
int
cur_count
=
0
;
int
max_count
=
int
max_count
=
block_tile_welford_calculate_max_count
<
typename
Problem
::
BlockShape
>
(
row_size
);
block_tile_welford_calculate_max_count
<
typename
Problem
::
BlockShape
>
(
row_size
);
auto
block_welford
=
Policy
::
template
GetBlockWelford
<
Problem
>();
auto
block_norm_reduce
=
Policy
::
template
GetBlockNormReduce
<
Problem
>();
auto
block_welford_sync
=
Policy
::
template
GetBlockWelfordSync
<
Problem
>();
auto
block_norm_reduce_sync
=
Policy
::
template
GetBlockNormReduceSync
<
Problem
>();
auto
block_welford_cross_warp_sync
=
auto
block_norm_reduce_cross_warp_sync
=
Policy
::
template
GetBlockWelfordCrossWarpSync
<
Problem
>();
Policy
::
template
GetBlockNormReduceCrossWarpSync
<
Problem
>();
using
XTensorType
=
decltype
(
cast_tile
<
ComputeDataType
>
(
x
));
auto
mean
=
block_norm_reduce
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
auto
var
=
block_norm_reduce
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
clear_tile
(
mean
);
clear_tile
(
var
);
// load gamma/beta (TODO: support no gamma/beta?)
// load gamma/beta (TODO: support no gamma/beta?)
const
auto
gamma
=
load_tile
(
gamma_window
);
const
auto
gamma
=
load_tile
(
gamma_window
);
const
auto
beta
=
load_tile
(
beta_window
);
const
auto
beta
=
load_tile
(
beta_window
);
...
@@ -117,12 +123,21 @@ struct Layernorm2dFwdPipelineOnePass
...
@@ -117,12 +123,21 @@ struct Layernorm2dFwdPipelineOnePass
store_tile
(
y_residual_window
,
cast_tile
<
YResidualDataType
>
(
acc
));
store_tile
(
y_residual_window
,
cast_tile
<
YResidualDataType
>
(
acc
));
}
}
// compute welford each-thread->cross-lane->cross-warp
// compute reduce each-thread->cross-lane->cross-warp
auto
[
mean
,
var
]
=
block_welford
(
acc
,
cur_count
,
max_count
);
block_norm_reduce
(
acc
,
mean
,
var
,
cur_count
,
max_count
);
block_welford_sync
(
mean
,
var
,
cur_count
);
block_norm_reduce_sync
(
mean
,
var
,
cur_count
);
block_welford_cross_warp_sync
(
mean
,
var
,
cur_count
,
smem
);
block_norm_reduce_cross_warp_sync
(
mean
,
var
,
cur_count
,
smem
);
block_tile_welford_post_scale_var
(
var
,
cur_count
,
constant
<
kFastFDiv
>
{});
if
(
kWelford
)
{
block_tile_welford_post_scale_var
(
var
,
cur_count
,
constant
<
kFastFDiv
>
{});
}
else
{
sweep_tile
(
mean
,
[
&
](
auto
idx
)
{
mean
(
idx
)
=
mean
(
idx
)
/
type_convert
<
MeanDataType
>
(
row_size
);
var
(
idx
)
=
var
(
idx
)
/
type_convert
<
MeanDataType
>
(
row_size
)
-
mean
(
idx
)
*
mean
(
idx
);
});
}
// compute inv-std
// compute inv-std
auto
inv_std
=
tile_elementwise_in
(
auto
inv_std
=
tile_elementwise_in
(
[
&
](
const
auto
&
v_
)
{
[
&
](
const
auto
&
v_
)
{
...
@@ -153,8 +168,7 @@ struct Layernorm2dFwdPipelineOnePass
...
@@ -153,8 +168,7 @@ struct Layernorm2dFwdPipelineOnePass
const
auto
beta_
=
type_convert
<
ComputeDataType
>
(
beta
[
j_idx
]);
const
auto
beta_
=
type_convert
<
ComputeDataType
>
(
beta
[
j_idx
]);
auto
ln_
=
(
acc
[
idx
]
-
mean_
[
i_idx
])
*
inv_std
[
i_idx
]
*
gamma_
+
beta_
;
auto
ln_
=
(
acc
[
idx
]
-
mean_
[
i_idx
])
*
inv_std
[
i_idx
]
*
gamma_
+
beta_
;
ln
(
idx
)
=
ln_
;
ln
(
idx
)
=
ln_
;
});
});
if
constexpr
(
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
DYNAMIC_QUANT
||
if
constexpr
(
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
DYNAMIC_QUANT
||
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
View file @
8b49f207
...
@@ -36,6 +36,7 @@ struct Layernorm2dFwdPipelineTwoPass
...
@@ -36,6 +36,7 @@ struct Layernorm2dFwdPipelineTwoPass
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockLayernorm2dFwdProblem::kPadM
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockLayernorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
bool
kFastFDiv
=
Problem
::
Traits
::
kFastFDiv
;
static
constexpr
bool
kFastFDiv
=
Problem
::
Traits
::
kFastFDiv
;
static
constexpr
bool
kWelford
=
Problem
::
Traits
::
kWelford
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
...
@@ -77,6 +78,7 @@ struct Layernorm2dFwdPipelineTwoPass
...
@@ -77,6 +78,7 @@ struct Layernorm2dFwdPipelineTwoPass
void
*
smem
,
void
*
smem
,
Epilogue
)
const
Epilogue
)
const
{
{
static_assert
(
kWelford
==
true
,
"2 pass only supports welford merge"
);
auto
x_window
=
auto
x_window
=
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
gamma_window
=
make_tile_window
(
auto
gamma_window
=
make_tile_window
(
...
@@ -102,14 +104,14 @@ struct Layernorm2dFwdPipelineTwoPass
...
@@ -102,14 +104,14 @@ struct Layernorm2dFwdPipelineTwoPass
int
max_count
=
int
max_count
=
(
num_n_tile_iteration
-
1
)
*
count_per_iter
+
(
num_n_tile_iteration
-
1
)
*
count_per_iter
+
block_tile_welford_calculate_max_count
<
typename
Problem
::
BlockShape
>
(
last_iter_n
);
block_tile_welford_calculate_max_count
<
typename
Problem
::
BlockShape
>
(
last_iter_n
);
auto
block_
welford
=
Policy
::
template
GetBlock
Welford
<
Problem
>();
auto
block_
norm_reduce
=
Policy
::
template
GetBlock
NormReduce
<
Problem
>();
auto
block_
welford
_sync
=
Policy
::
template
GetBlock
Welford
Sync
<
Problem
>();
auto
block_
norm_reduce
_sync
=
Policy
::
template
GetBlock
NormReduce
Sync
<
Problem
>();
auto
block_
welford
_cross_warp_sync
=
auto
block_
norm_reduce
_cross_warp_sync
=
Policy
::
template
GetBlock
Welford
CrossWarpSync
<
Problem
>();
Policy
::
template
GetBlock
NormReduce
CrossWarpSync
<
Problem
>();
using
XTensorType
=
decltype
(
cast_tile
<
ComputeDataType
>
(
load_tile
(
x_window
)));
using
XTensorType
=
decltype
(
cast_tile
<
ComputeDataType
>
(
load_tile
(
x_window
)));
auto
mean
=
block_
welford
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
auto
mean
=
block_
norm_reduce
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
auto
var
=
block_
welford
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
auto
var
=
block_
norm_reduce
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
{
...
@@ -133,11 +135,11 @@ struct Layernorm2dFwdPipelineTwoPass
...
@@ -133,11 +135,11 @@ struct Layernorm2dFwdPipelineTwoPass
move_tile_window
(
y_residual_window
,
{
0
,
Block_N
});
move_tile_window
(
y_residual_window
,
{
0
,
Block_N
});
}
}
}
}
block_
welford
(
acc
,
mean
,
var
,
cur_count
,
max_count
);
block_
norm_reduce
(
acc
,
mean
,
var
,
cur_count
,
max_count
);
}
}
block_
welford
_sync
(
mean
,
var
,
cur_count
);
block_
norm_reduce
_sync
(
mean
,
var
,
cur_count
);
block_
welford
_cross_warp_sync
(
mean
,
var
,
cur_count
,
smem
);
block_
norm_reduce
_cross_warp_sync
(
mean
,
var
,
cur_count
,
smem
);
block_tile_welford_post_scale_var
(
var
,
cur_count
,
constant
<
kFastFDiv
>
{});
block_tile_welford_post_scale_var
(
var
,
cur_count
,
constant
<
kFastFDiv
>
{});
// compute inv-std
// compute inv-std
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp
View file @
8b49f207
...
@@ -40,6 +40,7 @@ template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOT
...
@@ -40,6 +40,7 @@ template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOT
template
<
bool
kPadN_
,
template
<
bool
kPadN_
,
bool
kSaveMeanInvStd_
,
bool
kSaveMeanInvStd_
,
bool
kFastFDiv_
,
bool
kFastFDiv_
,
bool
kWelford_
,
bool
kTwoPass_
,
bool
kTwoPass_
,
Layernorm2dFusedAddEnum
kFusedAdd_
,
Layernorm2dFusedAddEnum
kFusedAdd_
,
Layernorm2dFusedQuantEnum
kFusedQuant_
>
Layernorm2dFusedQuantEnum
kFusedQuant_
>
...
@@ -48,6 +49,7 @@ struct Layernorm2dFwdTraits
...
@@ -48,6 +49,7 @@ struct Layernorm2dFwdTraits
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kSaveMeanInvStd
=
kSaveMeanInvStd_
;
static
constexpr
bool
kSaveMeanInvStd
=
kSaveMeanInvStd_
;
static
constexpr
bool
kFastFDiv
=
kFastFDiv_
;
static
constexpr
bool
kFastFDiv
=
kFastFDiv_
;
static
constexpr
bool
kWelford
=
kWelford_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
static
constexpr
Layernorm2dFusedAddEnum
kFusedAdd
=
kFusedAdd_
;
static
constexpr
Layernorm2dFusedAddEnum
kFusedAdd
=
kFusedAdd_
;
static
constexpr
Layernorm2dFusedQuantEnum
kFusedQuant
=
kFusedQuant_
;
static
constexpr
Layernorm2dFusedQuantEnum
kFusedQuant
=
kFusedQuant_
;
...
...
include/ck_tile/ops/
welford
.hpp
→
include/ck_tile/ops/
norm_reduce
.hpp
View file @
8b49f207
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck_tile/ops/
welford
/block/block_
welford
.hpp"
#include "ck_tile/ops/
norm_reduce
/block/block_
norm_reduce
.hpp"
#include "ck_tile/ops/
welford
/block/block_
welford
_problem.hpp"
#include "ck_tile/ops/
norm_reduce
/block/block_
norm_reduce
_problem.hpp"
#include "ck_tile/ops/
welford
/thread/thread_welford.hpp"
#include "ck_tile/ops/
norm_reduce
/thread/thread_welford.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/
welford
/block/block_
welford
.hpp
→
include/ck_tile/ops/
norm_reduce
/block/block_
norm_reduce
.hpp
View file @
8b49f207
...
@@ -4,22 +4,23 @@
...
@@ -4,22 +4,23 @@
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/
welford
/thread/thread_welford.hpp"
#include "ck_tile/ops/
norm_reduce
/thread/thread_welford.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
template
<
typename
Problem_
,
typename
Policy_
=
void
>
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
Block
Welford
struct
Block
NormReduce
{
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
XDataType
=
typename
Problem
::
XDataType
;
using
XDataType
=
typename
Problem
::
XDataType
;
using
ComputeDataType
=
typename
Problem
::
ComputeDataType
;
using
ComputeDataType
=
typename
Problem
::
ComputeDataType
;
static
constexpr
bool
kFastFDiv
=
Problem
::
kFastFDiv
;
static
constexpr
bool
kFastFDiv
=
Problem
::
kFastFDiv
;
static
constexpr
bool
kWelford
=
Problem
::
kWelford
;
CK_TILE_DEVICE
constexpr
Block
Welford
()
{}
CK_TILE_DEVICE
constexpr
Block
NormReduce
()
{}
// [CAUSION] - max_count_ is to deal with the padding problem
// [CAUSION] - max_count_ is to deal with the padding problem
// max_count_ is depend on caller, eg: naive and splitN
welford
will have different
// max_count_ is depend on caller, eg: naive and splitN
norm_reduce
will have different
// calculation of max_count_
// calculation of max_count_
// -> use block_welford_calculate_max_count to compute
// -> use block_welford_calculate_max_count to compute
template
<
typename
XDistributedTensor_
,
template
<
typename
XDistributedTensor_
,
...
@@ -40,18 +41,24 @@ struct BlockWelford
...
@@ -40,18 +41,24 @@ struct BlockWelford
if
(
cur_count_
<
max_count_
)
if
(
cur_count_
<
max_count_
)
{
{
++
cur_count_
;
++
cur_count_
;
sweep_tile_span
(
spans
[
I0
],
[
&
](
auto
dstr_idx_i0
)
{
sweep_tile_span
(
spans
[
I0
],
[
&
](
auto
dstr_idx_i0
)
{
constexpr
auto
in_dstr_idx
=
make_tuple
(
dstr_idx_i0
,
dstr_idx_i1
);
constexpr
auto
in_dstr_idx
=
make_tuple
(
dstr_idx_i0
,
dstr_idx_i1
);
constexpr
auto
out_dstr_idx
=
make_tuple
(
dstr_idx_i0
);
constexpr
auto
out_dstr_idx
=
make_tuple
(
dstr_idx_i0
);
auto
x
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_tensor
[
in_dstr_idx
]);
auto
x
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_tensor
[
in_dstr_idx
]);
if
(
kWelford
)
welford_update
(
mean_tensor
(
out_dstr_idx
),
{
var_tensor
(
out_dstr_idx
),
welford_update
(
mean_tensor
(
out_dstr_idx
),
x
,
var_tensor
(
out_dstr_idx
),
cur_count_
,
x
,
constant
<
kFastFDiv
>
{});
cur_count_
,
constant
<
kFastFDiv
>
{});
}
else
{
mean_tensor
(
out_dstr_idx
)
+=
x
;
var_tensor
(
out_dstr_idx
)
+=
x
*
x
;
}
});
});
}
}
});
});
...
@@ -91,10 +98,11 @@ struct BlockWelford
...
@@ -91,10 +98,11 @@ struct BlockWelford
};
};
template
<
typename
Problem_
,
typename
Policy_
=
void
>
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
Block
Welford
Sync
struct
Block
NormReduce
Sync
{
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Problem
=
remove_cvref_t
<
Problem_
>
;
static
constexpr
bool
kFastFDiv
=
Problem
::
kFastFDiv
;
static
constexpr
bool
kFastFDiv
=
Problem
::
kFastFDiv
;
static
constexpr
bool
kWelford
=
Problem
::
kWelford
;
template
<
typename
MeanDistributedTensor_
,
typename
VarDistributedTensor_
>
template
<
typename
MeanDistributedTensor_
,
typename
VarDistributedTensor_
>
CK_TILE_DEVICE
void
CK_TILE_DEVICE
void
...
@@ -152,36 +160,48 @@ struct BlockWelfordSync
...
@@ -152,36 +160,48 @@ struct BlockWelfordSync
(
number
<
lid_over_rid_derivative
<<
istage
.
value
>
{}.
value
);
(
number
<
lid_over_rid_derivative
<<
istage
.
value
>
{}.
value
);
// pull data from remote lane
// pull data from remote lane
const
auto
v_remote_mean
=
warp_shuffle
(
v_local_mean
,
src_lane
);
const
auto
v_remote_mean
=
warp_shuffle
(
v_local_mean
,
src_lane
);
const
auto
v_remote_var
=
warp_shuffle
(
v_local_var
,
src_lane
);
const
auto
v_remote_var
=
warp_shuffle
(
v_local_var
,
src_lane
);
const
auto
v_remote_count
=
warp_shuffle
(
v_local_count
,
src_lane
);
if
(
kWelford
)
{
// welford merge
const
auto
v_remote_count
=
warp_shuffle
(
v_local_count
,
src_lane
);
welford_merge
(
v_local_mean
,
v_local_var
,
// norm_reduce merge
v_local_count
,
welford_merge
(
v_local_mean
,
v_remote_mean
,
v_local_var
,
v_remote_var
,
v_local_count
,
v_remote_count
,
v_remote_mean
,
constant
<
kFastFDiv
>
{});
v_remote_var
,
v_remote_count
,
constant
<
kFastFDiv
>
{});
}
else
{
v_local_mean
+=
v_remote_mean
;
v_local_var
+=
v_remote_var
;
}
});
});
}
}
});
});
mean_tensor
.
get_thread_buffer
()(
i
)
=
v_local_mean
;
mean_tensor
.
get_thread_buffer
()(
i
)
=
v_local_mean
;
var_tensor
.
get_thread_buffer
()(
i
)
=
v_local_var
;
var_tensor
.
get_thread_buffer
()(
i
)
=
v_local_var
;
if
(
kWelford
)
count
=
v_local_count
;
{
count
=
v_local_count
;
}
});
});
}
}
};
};
template
<
typename
Problem_
,
typename
Policy_
=
void
>
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
Block
Welford
CrossWarpSync
struct
Block
NormReduce
CrossWarpSync
{
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
BlockShape
=
typename
Problem
::
BlockShape
;
using
BlockShape
=
typename
Problem
::
BlockShape
;
static
constexpr
bool
kFastFDiv
=
Problem
::
kFastFDiv
;
static
constexpr
bool
kFastFDiv
=
Problem
::
kFastFDiv
;
static
constexpr
bool
kWelford
=
Problem
::
kWelford
;
using
smem_dtype
=
std
::
conditional_t
<
kWelford
,
fp32x4_t
,
fp32x2_t
>
;
template
<
typename
MeanDistributedTensor_
>
template
<
typename
MeanDistributedTensor_
>
CK_TILE_DEVICE
static
constexpr
index_t
GetReduceWarps
()
CK_TILE_DEVICE
static
constexpr
index_t
GetReduceWarps
()
...
@@ -252,7 +272,7 @@ struct BlockWelfordCrossWarpSync
...
@@ -252,7 +272,7 @@ struct BlockWelfordCrossWarpSync
static_assert
(
thread_buf_size
==
VarDistributedTensor_
::
get_thread_buffer_size
());
static_assert
(
thread_buf_size
==
VarDistributedTensor_
::
get_thread_buffer_size
());
// Note: we always pack everything into fp32x4
// Note: we always pack everything into fp32x4
fp32x4_t
*
smem_ptr
=
reinterpret_cast
<
fp32x4_t
*>
(
smem
);
smem_dtype
*
smem_ptr
=
reinterpret_cast
<
smem_dtype
*>
(
smem
);
const
index_t
lane_id
=
get_lane_id
();
const
index_t
lane_id
=
get_lane_id
();
const
index_t
warp_id
=
get_warp_id
();
const
index_t
warp_id
=
get_warp_id
();
constexpr
auto
num_reduce_warps
=
GetReduceWarps
<
MeanDistributedTensor_
>
();
constexpr
auto
num_reduce_warps
=
GetReduceWarps
<
MeanDistributedTensor_
>
();
...
@@ -267,11 +287,13 @@ struct BlockWelfordCrossWarpSync
...
@@ -267,11 +287,13 @@ struct BlockWelfordCrossWarpSync
if
(
lane_id
==
0
)
if
(
lane_id
==
0
)
{
{
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i
)
{
fp32x4_t
local_scratch_
;
smem_dtype
local_scratch_
;
local_scratch_
[
0
]
=
bit_cast
<
float
>
(
mean_tensor
.
get_thread_buffer
()[
i
]);
local_scratch_
[
0
]
=
bit_cast
<
float
>
(
mean_tensor
.
get_thread_buffer
()[
i
]);
local_scratch_
[
1
]
=
bit_cast
<
float
>
(
var_tensor
.
get_thread_buffer
()[
i
]);
local_scratch_
[
1
]
=
bit_cast
<
float
>
(
var_tensor
.
get_thread_buffer
()[
i
]);
local_scratch_
[
2
]
=
bit_cast
<
float
>
(
count
);
if
(
kWelford
)
{
local_scratch_
[
2
]
=
bit_cast
<
float
>
(
count
);
}
smem_ptr
[
smem_offset
+
i
*
num_warps
]
=
local_scratch_
;
smem_ptr
[
smem_offset
+
i
*
num_warps
]
=
local_scratch_
;
});
});
}
}
...
@@ -280,7 +302,7 @@ struct BlockWelfordCrossWarpSync
...
@@ -280,7 +302,7 @@ struct BlockWelfordCrossWarpSync
// load from smem. here we let everythread to do compute :)
// load from smem. here we let everythread to do compute :)
index_t
local_warp_id
=
warp_id
/
num_reduce_warps
;
index_t
local_warp_id
=
warp_id
/
num_reduce_warps
;
index_t
local_smem_os
=
local_warp_id
*
num_reduce_warps
;
index_t
local_smem_os
=
local_warp_id
*
num_reduce_warps
;
fp32x4_t
all_scratch
[
thread_buf_size
*
num_reduce_warps
];
smem_dtype
all_scratch
[
thread_buf_size
*
num_reduce_warps
];
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i_0
)
{
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i_0
)
{
static_for
<
0
,
num_reduce_warps
,
1
>
{}([
&
](
auto
i_1
)
{
static_for
<
0
,
num_reduce_warps
,
1
>
{}([
&
](
auto
i_1
)
{
all_scratch
[
i_0
*
num_reduce_warps
+
i_1
]
=
all_scratch
[
i_0
*
num_reduce_warps
+
i_1
]
=
...
@@ -293,32 +315,40 @@ struct BlockWelfordCrossWarpSync
...
@@ -293,32 +315,40 @@ struct BlockWelfordCrossWarpSync
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i_0
)
{
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i_0
)
{
// TODO: use descriptor for this
// TODO: use descriptor for this
auto
v_local
=
all_scratch
[
i_0
*
num_reduce_warps
];
auto
v_local
=
all_scratch
[
i_0
*
num_reduce_warps
];
auto
v_local_mean
=
bit_cast
<
DataType
>
(
v_local
[
0
]);
auto
v_local_mean
=
bit_cast
<
DataType
>
(
v_local
[
0
]);
auto
v_local_var
=
bit_cast
<
DataType
>
(
v_local
[
1
]);
auto
v_local_var
=
bit_cast
<
DataType
>
(
v_local
[
1
]);
auto
v_local_count
=
bit_cast
<
int
>
(
v_local
[
2
]);
int
v_local_count
=
kWelford
?
bit_cast
<
int
>
(
v_local
[
2
])
:
0
;
// further reduce mean/var
// further reduce mean/var
static_for
<
0
,
num_reduce_warps
-
1
,
1
>
{}([
&
](
auto
i_1_n1
)
{
static_for
<
0
,
num_reduce_warps
-
1
,
1
>
{}([
&
](
auto
i_1_n1
)
{
constexpr
auto
i_1
=
number
<
i_1_n1
+
1
>
{};
constexpr
auto
i_1
=
number
<
i_1_n1
+
1
>
{};
const
fp32x4_t
v_remote
=
all_scratch
[
i_0
*
num_reduce_warps
+
i_1
];
const
smem_dtype
v_remote
=
all_scratch
[
i_0
*
num_reduce_warps
+
i_1
];
const
auto
v_remote_mean
=
bit_cast
<
DataType
>
(
v_remote
[
0
]);
const
auto
v_remote_mean
=
bit_cast
<
DataType
>
(
v_remote
[
0
]);
const
auto
v_remote_var
=
bit_cast
<
DataType
>
(
v_remote
[
1
]);
const
auto
v_remote_var
=
bit_cast
<
DataType
>
(
v_remote
[
1
]);
const
auto
v_remote_count
=
bit_cast
<
int
>
(
v_remote
[
2
]);
if
(
kWelford
)
{
welford_merge
(
v_local_mean
,
const
auto
v_remote_count
=
bit_cast
<
int
>
(
v_remote
[
2
]);
v_local_var
,
v_local_count
,
welford_merge
(
v_local_mean
,
v_remote_mean
,
v_local_var
,
v_remote_var
,
v_local_count
,
v_remote_count
,
v_remote_mean
,
constant
<
kFastFDiv
>
{});
v_remote_var
,
v_remote_count
,
constant
<
kFastFDiv
>
{});
}
else
{
v_local_mean
+=
v_remote_mean
;
v_local_var
+=
v_remote_var
;
}
});
});
mean_tensor
.
get_thread_buffer
()(
i_0
)
=
v_local_mean
;
mean_tensor
.
get_thread_buffer
()(
i_0
)
=
v_local_mean
;
var_tensor
.
get_thread_buffer
()(
i_0
)
=
v_local_var
;
var_tensor
.
get_thread_buffer
()(
i_0
)
=
v_local_var
;
if
(
kWelford
)
count
=
v_local_count
;
count
=
v_local_count
;
});
});
}
}
};
};
...
...
include/ck_tile/ops/
welford
/block/block_
welford
_problem.hpp
→
include/ck_tile/ops/
norm_reduce
/block/block_
norm_reduce
_problem.hpp
View file @
8b49f207
...
@@ -7,13 +7,18 @@
...
@@ -7,13 +7,18 @@
namespace
ck_tile
{
namespace
ck_tile
{
template
<
typename
XDataType_
,
typename
ComputeDataType_
,
typename
BlockShape_
,
bool
kFastFDiv_
>
template
<
typename
XDataType_
,
struct
BlockWelfordProblem
typename
ComputeDataType_
,
typename
BlockShape_
,
bool
kFastFDiv_
,
bool
kWelford_
>
struct
BlockNormReduceProblem
{
{
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
bool
kFastFDiv
=
kFastFDiv_
;
static
constexpr
bool
kFastFDiv
=
kFastFDiv_
;
static
constexpr
bool
kWelford
=
kWelford_
;
};
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/
welford
/thread/thread_welford.hpp
→
include/ck_tile/ops/
norm_reduce
/thread/thread_welford.hpp
View file @
8b49f207
File moved
include/ck_tile/ops/permute.hpp
View file @
8b49f207
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck_tile/ops/reduce.hpp
View file @
8b49f207
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck_tile/ops/rmsnorm2d.hpp
View file @
8b49f207
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck_tile/ops/smoothquant.hpp
View file @
8b49f207
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp
View file @
8b49f207
...
@@ -19,7 +19,8 @@ struct SmoothquantHostArgs
...
@@ -19,7 +19,8 @@ struct SmoothquantHostArgs
index_t
m
;
index_t
m
;
index_t
n
;
index_t
n
;
index_t
stride
;
// row_stride
index_t
x_stride
;
// input row_stride
index_t
y_stride
;
// output row_stride
};
};
// TODO: Extract some type to wrapper class
// TODO: Extract some type to wrapper class
...
@@ -58,14 +59,21 @@ struct Smoothquant
...
@@ -58,14 +59,21 @@ struct Smoothquant
index_t
m
;
index_t
m
;
index_t
n
;
index_t
n
;
index_t
stride
;
// row_stride
index_t
x_stride
;
// input row_stride
index_t
y_stride
;
// out row_stride
};
};
using
Hargs
=
SmoothquantHostArgs
;
using
Hargs
=
SmoothquantHostArgs
;
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
{
{
return
Kargs
{
return
Kargs
{
hargs
.
p_x
,
hargs
.
p_x
,
hargs
.
p_xscale
,
hargs
.
p_yscale
,
hargs
.
p_qy
,
hargs
.
m
,
hargs
.
n
,
hargs
.
stride
};
hargs
.
p_xscale
,
hargs
.
p_yscale
,
hargs
.
p_qy
,
hargs
.
m
,
hargs
.
n
,
hargs
.
x_stride
,
hargs
.
y_stride
};
}
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
...
@@ -116,7 +124,7 @@ struct Smoothquant
...
@@ -116,7 +124,7 @@ struct Smoothquant
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
),
make_tuple
(
kargs
.
x_
stride
,
1
),
number
<
Vector_N
>
{},
number
<
Vector_N
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
@@ -157,7 +165,7 @@ struct Smoothquant
...
@@ -157,7 +165,7 @@ struct Smoothquant
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
QYDataType
*>
(
kargs
.
p_qy
),
static_cast
<
QYDataType
*>
(
kargs
.
p_qy
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
),
make_tuple
(
kargs
.
y_
stride
,
1
),
number
<
Vector_N
>
{},
number
<
Vector_N
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
...
include/ck_tile/ops/softmax.hpp
View file @
8b49f207
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck_tile/ops/topk.hpp
View file @
8b49f207
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck_tile/ops/topk_softmax.hpp
View file @
8b49f207
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
Prev
1
…
4
5
6
7
8
9
10
11
12
…
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment