gaoqiong / composable_kernel

Unverified commit f3baea0d, authored Sep 12, 2023 by Chao Liu, committed by GitHub Sep 12, 2023
Gemm+softmax+gemm (#9)
* adding gemm+softmax+gemm
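The commit adds a fused GEMM + softmax + GEMM example: C0 = A0 * B0, D0 = softmax(C0) applied row-wise, C1 = D0 * B1 (the core of an attention layer). As a reading aid, here is a minimal host-side sketch of that computation in plain C++; it is written for this review, not code from the commit, whose own host reference path is reference_gemm + reference_softmax + reference_gemm below.

#include <cmath>
#include <vector>

// Shapes follow the example's convention: A0[M0, K0], B0[N0, K0],
// B1[N1, N0], C1[M0, N1], all row-major, so both GEMMs reduce along
// the contiguous axis of each operand.
void gemm_softmax_gemm_ref(const std::vector<float>& a0,
                           const std::vector<float>& b0,
                           const std::vector<float>& b1,
                           std::vector<float>& c1,
                           int M0, int N0, int K0, int N1)
{
    std::vector<float> d0(N0); // one softmax row at a time
    for(int i = 0; i < M0; ++i)
    {
        // C0[i][j] = sum_k A0[i][k] * B0[j][k]
        float row_max = -INFINITY;
        for(int j = 0; j < N0; ++j)
        {
            float acc = 0.f;
            for(int k = 0; k < K0; ++k)
                acc += a0[i * K0 + k] * b0[j * K0 + k];
            d0[j]   = acc;
            row_max = std::fmax(row_max, acc);
        }
        // D0[i] = softmax(C0[i]), with the max subtracted for stability
        float sum = 0.f;
        for(int j = 0; j < N0; ++j)
        {
            d0[j] = std::exp(d0[j] - row_max);
            sum += d0[j];
        }
        // C1[i][n] = sum_j D0[i][j] * B1[n][j]
        for(int n = 0; n < N1; ++n)
        {
            float acc = 0.f;
            for(int j = 0; j < N0; ++j)
                acc += (d0[j] / sum) * b1[n * N0 + j];
            c1[i * N1 + n] = acc;
        }
    }
}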
parent cfdce3eb

Showing 11 changed files with 654 additions and 91 deletions (+654 -91)
example/91_tile_program/CMakeLists.txt (+1 -0)
example/91_tile_program/gemm_gemm.hpp (+0 -42)
example/91_tile_program/gemm_softmax_gemm.cpp (+155 -0)
example/91_tile_program/gemm_softmax_gemm.hpp (+395 -0)
example/91_tile_program/reference_softmax.hpp (+45 -0)
example/91_tile_program/softmax.cpp (+6 -40)
include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_default_policy.hpp (+4 -0)
include/ck/tile_program/block_tile/block_reduce.hpp (+31 -4)
include/ck/tile_program/tile/tile_elementwise.hpp (+1 -1)
include/ck/tile_program/warp_tile/warp_gemm.hpp (+3 -0)
library/include/ck/library/utility/host_tensor.hpp (+13 -4)
example/91_tile_program/CMakeLists.txt (+1 -0)

@@ -4,3 +4,4 @@ add_example_executable(example_gemm gemm.cpp)
 add_example_executable(example_gemm_gemm gemm_gemm.cpp)
 add_example_executable(example_reduce reduce.cpp)
 add_example_executable(example_softmax softmax.cpp)
+add_example_executable(example_gemm_softmax_gemm gemm_softmax_gemm.cpp)
example/91_tile_program/gemm_gemm.hpp (+0 -42)

@@ -215,7 +215,6 @@ struct GemmGemm
         // init Acc1
         tile_elementwise_inout([](auto& acc1) { acc1 = 0; }, acc1_block_tile);

-#if 0
         index_t iN0 = 0;

         do
@@ -255,47 +254,6 @@ struct GemmGemm
             iN0 += kN0PerBlock;

         } while(iN0 < N0);
-#else
-        index_t iN0 = 0;
-
-        do
-        {
-            // load b1
-            const auto b1_block_tile = load_tile(b1_dram_block_window);
-
-            // Block GEMM0 pipeline: acc0 = a0 * b0
-            const auto acc0_block_tile = block_gemm0_pipeline(
-                a0_dram_block_window, b0_dram_block_window, K0 / kK0PerBlock, p_smem_char);
-
-            // type cast acc0 into c0
-            const auto c0_block_tile =
-                tile_elementwise_in(type_convert<C0DataType, Acc0DataType>, acc0_block_tile);
-
-            // Block GEMM1: acc1 += c0 * b1
-            {
-                // wait for block gemm0 pipeline to finish
-                ps.block_sync_lds();
-
-                store_tile(b1_lds_block_window, b1_block_tile);
-
-                // wait for store_tile to finish
-                ps.block_sync_lds();
-
-                // acc1 += c0 * b1
-                block_gemm1(acc1_block_tile, c0_block_tile, b1_lds_block_window);
-
-                // wait for block gemm1 to finish
-                ps.block_sync_lds();
-            }
-
-            // move tile windows
-            move_tile_window(b0_dram_block_window, {kN0PerBlock, 0});
-            move_tile_window(b1_dram_block_window, {0, kN0PerBlock});
-
-            iN0 += kN0PerBlock;
-
-        } while(iN0 < N0);
-#endif

         // type cast acc1 into c1
         const auto c1_block_tile =
example/91_tile_program/gemm_softmax_gemm.cpp — new file (0 → 100644), +155 -0

#include <cstring>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor/tensor_view.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

#include "reference_gemm.hpp"
#include "reference_softmax.hpp"
#include "gemm_softmax_gemm.hpp"

int main(int argc, char* argv[])
{
    using A0DataType   = ck::half_t;
    using B0DataType   = ck::half_t;
    using Acc0DataType = float;
    using C0DataType   = ck::half_t;
    using D0DataType   = ck::half_t;
    using B1DataType   = ck::half_t;
    using Acc1DataType = float;
    using C1DataType   = ck::half_t;

    ck::index_t M0 = 13312;
    ck::index_t N0 = 4096;
    ck::index_t K0 = 128;
    ck::index_t N1 = 128;

    if(argc == 5)
    {
        M0 = std::stoi(argv[1]);
        N0 = std::stoi(argv[2]);
        K0 = std::stoi(argv[3]);
        N1 = std::stoi(argv[4]);
    }

    std::array<ck::index_t, 2> a0_lengths{M0, K0};
    std::array<ck::index_t, 2> a0_strides{K0, 1};
    std::array<ck::index_t, 2> b0_lengths{N0, K0};
    std::array<ck::index_t, 2> b0_strides{K0, 1};
    std::array<ck::index_t, 2> c0_lengths{M0, N0};
    std::array<ck::index_t, 2> c0_strides{N0, 1};
    std::array<ck::index_t, 2> d0_lengths{M0, N0};
    std::array<ck::index_t, 2> d0_strides{N0, 1};
    std::array<ck::index_t, 2> b1_lengths{N1, N0};
    std::array<ck::index_t, 2> b1_strides{N0, 1};
    std::array<ck::index_t, 2> c1_lengths{M0, N1};
    std::array<ck::index_t, 2> c1_strides{N1, 1};

    // host verify
    Tensor<A0DataType> a0_host(a0_lengths, a0_strides);
    Tensor<B0DataType> b0_host(b0_lengths, b0_strides);
    Tensor<C0DataType> c0_host_ref(c0_lengths, c0_strides);
    Tensor<D0DataType> d0_host_ref(d0_lengths, d0_strides);
    Tensor<B1DataType> b1_host(b1_lengths, b1_strides);
    Tensor<C1DataType> c1_host_ref(c1_lengths, c1_strides);
    Tensor<C1DataType> c1_host_dev(c1_lengths, c1_strides);

#if 1
    ck::utils::FillUniformDistributionIntegerValue<A0DataType>{-3.f, 3.f}(a0_host);
    ck::utils::FillUniformDistributionIntegerValue<B0DataType>{-3.f, 3.f}(b0_host);
    ck::utils::FillUniformDistributionIntegerValue<B1DataType>{-3.f, 3.f}(b1_host);
#elif 0
    ck::utils::FillUniformDistribution<A0DataType>{-3.f, 3.f}(a0_host);
    ck::utils::FillUniformDistribution<B0DataType>{-3.f, 3.f}(b0_host);
    ck::utils::FillUniformDistribution<B1DataType>{-3.f, 3.f}(b1_host);
#else
    ck::utils::FillConstant<A0DataType>{1.0f}(a0_host);
    ck::utils::FillConstant<A0DataType>{1.0f}(b0_host);
    ck::utils::FillConstant<A0DataType>{1.0f}(b1_host);
#endif

    // reference
    reference_gemm<A0DataType, B0DataType, C0DataType, float>(a0_host, b0_host, c0_host_ref);
    reference_softmax<C0DataType, float, D0DataType>(c0_host_ref, d0_host_ref);
    reference_gemm<D0DataType, B1DataType, C1DataType, float>(d0_host_ref, b1_host, c1_host_ref);

    DeviceMem a0_buf(sizeof(A0DataType) * a0_host.GetElementSpaceSize());
    DeviceMem b0_buf(sizeof(B0DataType) * b0_host.GetElementSpaceSize());
    DeviceMem b1_buf(sizeof(B1DataType) * b1_host.GetElementSpaceSize());
    DeviceMem c1_buf(sizeof(C1DataType) * c1_host_ref.GetElementSpaceSize());

    a0_buf.ToDevice(a0_host.mData.data());
    b0_buf.ToDevice(b0_host.mData.data());
    b1_buf.ToDevice(b1_host.mData.data());

    constexpr ck::index_t kM0PerBlock = 128;
    constexpr ck::index_t kN0PerBlock = 128;
    constexpr ck::index_t kK0PerBlock = 32;
    constexpr ck::index_t kN1PerBlock = 128;

    constexpr ck::index_t kBlockSize = 256;
    ck::index_t kGridSize            = (M0 / kM0PerBlock) * (N1 / kN1PerBlock);

    std::cout << "grid size " << kGridSize << std::endl;

    float ave_time = launch(ProgramServer{},
                            GemmSoftmaxGemm<A0DataType,
                                            B0DataType,
                                            Acc0DataType,
                                            C0DataType,
                                            B1DataType,
                                            Acc1DataType,
                                            C1DataType,
                                            kBlockSize,
                                            kM0PerBlock,
                                            kN0PerBlock,
                                            kK0PerBlock,
                                            kN1PerBlock>{},
                            kGridSize,
                            kBlockSize,
                            static_cast<A0DataType*>(a0_buf.GetDeviceBuffer()),
                            static_cast<B0DataType*>(b0_buf.GetDeviceBuffer()),
                            static_cast<B1DataType*>(b1_buf.GetDeviceBuffer()),
                            static_cast<C1DataType*>(c1_buf.GetDeviceBuffer()),
                            M0,
                            N0,
                            K0,
                            N1,
                            K0,  // Lda0
                            K0,  // Ldb0
                            N0,  // Ldb1
                            N1); // Ldc1

    c1_buf.FromDevice(c1_host_dev.mData.data());

    std::size_t flop = std::size_t(2) * M0 * N0 * K0 + std::size_t(2) * M0 * N1 * N0;

    std::size_t num_btype = sizeof(A0DataType) * M0 * K0 + sizeof(B0DataType) * N0 * K0 +
                            sizeof(B1DataType) * N1 * N0 + sizeof(C1DataType) * M0 * N1;

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
              << std::endl;

    // LogRangeAsType<float>(std::cout << "C1 dev: ", c1_host_dev.mData, ", ", 16, 20) << std::endl;
    // LogRangeAsType<float>(std::cout << "C1 ref: ", c1_host_ref.mData, ", ", 16, 20) << std::endl;

    return !ck::utils::check_err(c1_host_dev, c1_host_ref);
}
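A quick check on the perf bookkeeping above: flop counts both GEMMs, 2*M0*N0*K0 + 2*M0*N1*N0. With the default sizes (M0 = 13312, N0 = 4096, K0 = 128, N1 = 128) each term is 2 * 13312 * 4096 * 128 ≈ 13.96e9, about 27.9 GFLOP in total. Since ave_time is in milliseconds, flop / 1.E9 / ave_time really is TFLOPS (1e9 flop per ms = 1e12 flop/s), and num_btype / 1.E6 / ave_time is GB/s (1e6 bytes per ms = 1e9 bytes/s).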
example/91_tile_program/gemm_softmax_gemm.hpp — new file (0 → 100644), +395 -0

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"

#include "tile_program.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck/tile_program/block_tile/block_reduce.hpp"

// C0 = A0 * B0
// D0 = softmax(C0)
// C1 = D0 * B1
template <typename A0DataType,
          typename B0DataType,
          typename Acc0DataType,
          typename C0DataType,
          typename B1DataType,
          typename Acc1DataType,
          typename C1DataType,
          ck::index_t kBlockSize,
          ck::index_t kM0PerBlock,
          ck::index_t kN0PerBlock,
          ck::index_t kK0PerBlock,
          ck::index_t kN1PerBlock>
struct GemmSoftmaxGemm
{
    // block gemm0 pipeline
    using BlockGemm0Pipeline = ck::tile_program::block::BlockGemmPipelineAGmemBGmemCRegV2<
        ck::tile_program::block::BlockGemmPipelineProblem<
            A0DataType,
            B0DataType,
            Acc0DataType,
            kBlockSize,
            ck::tile_program::TileGemmShape<kM0PerBlock, kN0PerBlock, kK0PerBlock>>,
        ck::tile_program::block::BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy>;

    // block gemm1
    using BlockGemm1 = ck::tile_program::block::BlockGemmARegBSmemCRegV1<
        ck::tile_program::block::BlockGemmARegBSmemCRegV1Problem<
            C0DataType,
            B1DataType,
            Acc1DataType,
            kBlockSize,
            ck::tile_program::TileGemmShape<kM0PerBlock, kN1PerBlock, kN0PerBlock>>,
        ck::tile_program::block::BlockGemmARegBSmemCRegV1DefaultPolicy>;

#if 0
    // 2d
    __host__ __device__ static constexpr auto MakeB1LdsBlockDescriptor()
    {
        using namespace ck;

        constexpr index_t kNPerBlock = kN1PerBlock;
        constexpr index_t kKPerBlock = kN0PerBlock;

        constexpr auto b_lds_block_desc =
            make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), Number<32>{});

        return b_lds_block_desc;
    }
#else
    // fake XOR
    __host__ __device__ static constexpr auto MakeB1LdsBlockDescriptor()
    {
        using namespace ck;

        using BDataType = B1DataType;

        constexpr index_t kNPerBlock = kN1PerBlock;
        constexpr index_t kKPerBlock = kN0PerBlock;

        constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
            make_tuple(kNPerBlock / 2, 2, kKPerBlock), Number<kKPerBlock>{});

        constexpr index_t kK1 = 16 / sizeof(BDataType);

        constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
            b_lds_block_desc_d1_d2_d3,
            make_tuple(make_xor_transform(make_tuple(kNPerBlock / 2, kKPerBlock), kK1),
                       make_pass_through_transform(2)),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

        constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
            b_lds_block_desc_d4_d5_d6,
            make_tuple(make_merge_transform(make_tuple(kNPerBlock / 2, 2)),
                       make_pass_through_transform(kKPerBlock)),
            make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return b_lds_block_desc_n_k;
    }
#endif

    __host__ __device__ static constexpr auto MakeB1DramTileDistribution()
    {
        using namespace ck;
        using namespace ck::tile_program;

        using BDataType = B1DataType;

        constexpr index_t kNPerBlock = kN1PerBlock;
        constexpr index_t kKPerBlock = kN0PerBlock;

        constexpr index_t K1 = 16 / sizeof(BDataType);
        constexpr index_t K0 = kKPerBlock / K1;
        constexpr index_t N2 = get_warp_size() / K0;
        constexpr index_t N1 = kBlockSize / get_warp_size();
        constexpr index_t N0 = kNPerBlock / (N2 * N1);

        return make_static_tile_distribution(
            StaticTileDistributionEncoding<Sequence<1>,
                                           Tuple<Sequence<N0, N1, N2>, Sequence<K0, K1>>,
                                           Tuple<Sequence<1>, Sequence<1, 2>>,
                                           Tuple<Sequence<1>, Sequence<2, 0>>,
                                           Sequence<1, 2>,
                                           Sequence<0, 1>>{});
    }

    __host__ __device__ static constexpr ck::index_t GetStaticLdsSize()
    {
        using namespace ck;

        return math::max(BlockGemm0Pipeline::GetStaticLdsSize(),
                         static_cast<index_t>(MakeB1LdsBlockDescriptor().GetElementSpaceSize() *
                                              sizeof(B1DataType)));
    }

    __host__ __device__ void operator()(ProgramServer& ps,
                                        const A0DataType* p_a0,
                                        const B0DataType* p_b0,
                                        const B1DataType* p_b1,
                                        C1DataType* p_c1,
                                        ck::index_t M0,
                                        ck::index_t N0,
                                        ck::index_t K0,
                                        ck::index_t N1,
                                        ck::index_t Lda0,
                                        ck::index_t Ldb0,
                                        ck::index_t Ldb1,
                                        ck::index_t Ldc1)
    {
        using namespace ck;
        using namespace ck::tile_program;
        using namespace ck::tile_program::block;

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        // FIXME: assume layout A0[M0, K0], B0[N0, K0], B1[N1, N0], C1[M0, N1]
        const auto a0_dram_grid = make_naive_tensor_view<AddressSpaceEnum::Global>(
            p_a0, make_tuple(M0, K0), make_tuple(Lda0, 1), Number<32>{}, Number<1>{});

        const auto b0_dram_grid = make_naive_tensor_view<AddressSpaceEnum::Global>(
            p_b0, make_tuple(N0, K0), make_tuple(Ldb0, 1), Number<32>{}, Number<1>{});

        const auto b1_dram_grid = make_naive_tensor_view<AddressSpaceEnum::Global>(
            p_b1, make_tuple(N1, N0), make_tuple(Ldb1, 1), Number<32>{}, Number<1>{});

        // divide problem
        const auto id_block = ps.get_block_id();

        const auto num_tile_m0 = M0 / kM0PerBlock;
        const auto num_tile_n1 = N1 / kN1PerBlock;

        const auto block2tile = ps(make_cluster_descriptor(make_tuple(num_tile_m0, num_tile_n1)));

        const auto id_tile = block2tile.CalculateBottomIndex(make_tuple(id_block));

        const auto iM0 = ps.read_first_lane(id_tile.At<0>() * kM0PerBlock);
        const auto iN1 = ps.read_first_lane(id_tile.At<1>() * kN1PerBlock);

        __shared__ char p_smem_char[GetStaticLdsSize()];

        // A0 DRAM block window
        auto a0_dram_block_window = make_tile_window(
            a0_dram_grid, make_tuple(Number<kM0PerBlock>{}, Number<kK0PerBlock>{}), {iM0, 0});

        // B0 DRAM block window
        auto b0_dram_block_window = make_tile_window(
            b0_dram_grid, make_tuple(Number<kN0PerBlock>{}, Number<kK0PerBlock>{}), {0, 0});

        // Block GEMM0 pipeline
        constexpr auto block_gemm0_pipeline = BlockGemm0Pipeline{};

        // B1 DRAM window
        auto b1_dram_block_window =
            make_tile_window(b1_dram_grid,
                             make_tuple(Number<kN1PerBlock>{}, Number<kN0PerBlock>{}),
                             {iN1, 0},
                             MakeB1DramTileDistribution());

        // B1 LDS tensor view: occupies the same LDS allocation as block_gemm0_pipeline
        auto b1_lds_block = make_tensor_view<AddressSpaceEnum::Lds>(
            reinterpret_cast<B1DataType*>(p_smem_char), MakeB1LdsBlockDescriptor());

        auto b1_lds_block_window = make_tile_window(
            b1_lds_block, make_tuple(Number<kN1PerBlock>{}, Number<kN0PerBlock>{}), {0, 0});

        // Block GEMM1
        constexpr auto block_gemm1 = BlockGemm1{};

        // Acc0 tile
        using Acc0BlockTileType = decltype(
            block_gemm0_pipeline(a0_dram_block_window, b0_dram_block_window, 0, nullptr));

        // Acc1 tile
        auto acc1_block_tile = decltype(block_gemm1(
            tile_elementwise_in(type_convert<C0DataType, Acc0DataType>, Acc0BlockTileType{}),
            b1_dram_block_window)){};

        const auto f_max = [](auto v0, auto v1) { return max(v0, v1); };
        const auto f_sum = [](auto v0, auto v1) { return v0 + v1; };

        // init Acc1
        tile_elementwise_inout([](auto& acc1) { acc1 = 0; }, acc1_block_tile);

        // m, l tile
        auto m = decltype(block_tile_reduce<Acc0DataType>(
            Acc0BlockTileType{}, Sequence<1>{}, f_max, Acc0DataType{0})){};

        auto l = make_static_distributed_tensor<Acc0DataType>(m.GetTileDistribution());

        // init m, l
        tile_elementwise_inout([](auto& m_v) { m_v = NumericLimits<Acc0DataType>::Lowest(); }, m);
        tile_elementwise_inout([](auto& l_v) { l_v = 0; }, l);

        index_t iN0 = 0;

        do
        {
            // S[i][j] = Q[i] * K[j]
            const auto acc0_block_tile = block_gemm0_pipeline(
                a0_dram_block_window, b0_dram_block_window, K0 / kK0PerBlock, p_smem_char);

            // rowmax(S[i][j])
            auto m_local = block_tile_reduce<Acc0DataType>(
                acc0_block_tile, Sequence<1>{}, f_max, NumericLimits<Acc0DataType>::Lowest());

            block_tile_reduce_sync(m_local, f_max);

            // m[i][j-1]
            const auto m_old = m;

            // m[i][j]
            tile_elementwise_inout(
                [](auto& m_v, auto m_old_v, auto m_local_v) { m_v = max(m_old_v, m_local_v); },
                m,
                m_old,
                m_local);

            // P[i][j]
            auto p =
                make_static_distributed_tensor<Acc0DataType>(acc0_block_tile.GetTileDistribution());

            constexpr auto p_spans = decltype(p)::GetDistributedSpans();

            sweep_tile_span(p_spans[I0], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);

                const auto m_v = m.GetElementFromTileDistributedIndices(i_idx);

                sweep_tile_span(p_spans[I1], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);

                    const auto s_v = acc0_block_tile.GetElementFromTileDistributedIndices(i_j_idx);

                    const auto p_v = math::exp(s_v - m_v);

                    p.SetElementFromTileDistributedIndices(i_j_idx, p_v);
                });
            });

            // rowsum(P[i][j])
            auto rowsum_p =
                block_tile_reduce<Acc0DataType>(p, Sequence<1>{}, f_sum, Acc0DataType{0});

            block_tile_reduce_sync(rowsum_p, f_sum);

            // l[i][j], O[i][j]
            sweep_tile_span(p_spans[I0], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);

                const auto m_old_v = m_old.GetElementFromTileDistributedIndices(i_idx);
                const auto m_v     = m.GetElementFromTileDistributedIndices(i_idx);
                const auto l_old_v = l.GetElementFromTileDistributedIndices(i_idx);

                const auto tmp  = math::exp(m_old_v - m_v);
                const auto tmp2 = 1 / tmp;

                auto l_v = tmp * l_old_v + rowsum_p.GetElementFromTileDistributedIndices(i_idx);

                l.SetElementFromTileDistributedIndices(i_idx, l_v);

                sweep_tile_span(p_spans[I1], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);

                    // O[i][j]
                    const auto o_old_v =
                        acc1_block_tile.GetElementFromTileDistributedIndices(i_j_idx);

#if 0 // debug
                    // this uses the same equation as the FA v2 paper, but produces -nan
                    const auto o_v = o_old_v * tmp2;
#elif 1
                    // this uses a different equation from the FA v2 paper, but produces the
                    // correct result
                    (void)tmp2;

                    const auto o_v = o_old_v * tmp;
#endif

                    acc1_block_tile.SetElementFromTileDistributedIndices(i_j_idx, o_v);
                });
            });

            // type cast p into a1
            const auto c0_block_tile =
                tile_elementwise_in(type_convert<C0DataType, Acc0DataType>, p);

            // Block GEMM1: acc1 += c0 * b1
            {
                // load b1
                const auto b1_block_tile = load_tile(b1_dram_block_window);

                // wait for block gemm0 pipeline to finish
                ps.block_sync_lds();

                store_tile(b1_lds_block_window, b1_block_tile);

                // wait for store_tile to finish
                ps.block_sync_lds();

                // acc1 += c0 * b1
                block_gemm1(acc1_block_tile, c0_block_tile, b1_lds_block_window);

                // wait for block gemm1 to finish
                ps.block_sync_lds();
            }

            // move tile windows
            move_tile_window(b0_dram_block_window, {kN0PerBlock, 0});
            move_tile_window(b1_dram_block_window, {0, kN0PerBlock});

            iN0 += kN0PerBlock;

        } while(iN0 < N0);

        // o[i][J-1]
        constexpr auto o_spans = decltype(acc1_block_tile)::GetDistributedSpans();

        sweep_tile_span(o_spans[I0], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);

            const auto l_v = l.GetElementFromTileDistributedIndices(i_idx);

            const auto tmp = 1 / l_v;

            sweep_tile_span(o_spans[I1], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);

                const auto o_v = acc1_block_tile.GetElementFromTileDistributedIndices(i_j_idx);

                const auto o_new_v = o_v * tmp;

                acc1_block_tile.SetElementFromTileDistributedIndices(i_j_idx, o_new_v);
            });
        });

        // type cast acc1 into c1
        const auto c1_block_tile =
            tile_elementwise_in(type_convert<C1DataType, Acc1DataType>, acc1_block_tile);

        // store c1
        auto c1_dram_grid = make_naive_tensor_view<AddressSpaceEnum::Global>(
            p_c1, make_tuple(M0, N1), make_tuple(Ldc1, 1), Number<32>{}, Number<1>{});

        auto c1_dram_window =
            make_tile_window(c1_dram_grid,
                             make_tuple(Number<kM0PerBlock>{}, Number<kN1PerBlock>{}),
                             {iM0, iN1},
                             c1_block_tile.GetTileDistribution());

        store_tile(c1_dram_window, c1_block_tile);
    }
};
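The do/while loop in operator() above is a flash-attention style online softmax: each row carries a running max m, a running sum l, and an un-normalized output acc1; whenever a new tile raises the max, the old partial l and acc1 are rescaled by exp(m_old - m_new), and a single division by l happens in the epilogue. The following scalar sketch, written for this review (it is not part of the commit), walks one element at a time through the same recurrence and checks it against the two-pass softmax:

#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
    const std::vector<float> s{0.5f, 2.0f, -1.0f, 3.0f}; // one row of S = A0 * B0
    const std::vector<float> v{1.0f, -2.0f, 0.5f, 4.0f}; // matching column of B1

    // online pass: one "tile" (here, one element) at a time,
    // mirroring the kernel's m / l / acc1 running statistics
    float m = -INFINITY, l = 0.f, o = 0.f;
    for(std::size_t j = 0; j < s.size(); ++j)
    {
        const float m_old = m;
        m                 = std::fmax(m_old, s[j]); // new running max
        const float p     = std::exp(s[j] - m);     // un-normalized probability
        const float scale = std::exp(m_old - m);    // rescale old partial sums
        l                 = scale * l + p;          // running denominator
        o                 = scale * o + p * v[j];   // running un-normalized output
    }
    o /= l; // final normalization, like the epilogue sweep over o_spans

    // two-pass reference: softmax(s) dotted with v
    float m_ref = -INFINITY, l_ref = 0.f, o_ref = 0.f;
    for(float x : s)
        m_ref = std::fmax(m_ref, x);
    for(float x : s)
        l_ref += std::exp(x - m_ref);
    for(std::size_t j = 0; j < s.size(); ++j)
        o_ref += std::exp(s[j] - m_ref) / l_ref * v[j];

    std::printf("online %f vs two-pass %f\n", o, o_ref);
    assert(std::fabs(o - o_ref) < 1e-5f);
    return 0;
}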
example/91_tile_program/reference_softmax.hpp — new file (0 → 100644), +45 -0

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/library/utility/host_tensor.hpp"

template <typename ADataType, typename AccDataType, typename BDataType>
void reference_softmax(const Tensor<ADataType>& a_m_n, Tensor<BDataType>& b_m_n)
{
    auto f = [&](auto m) {
        const int N = a_m_n.mDesc.GetLengths()[1];

        AccDataType v_max = ck::NumericLimits<ADataType>::Lowest();

        // max
        for(int n = 0; n < N; ++n)
        {
            const ADataType v_a = a_m_n(m, n);

            v_max = v_max < v_a ? v_a : v_max;
        }

        AccDataType v_exp_sum = 0;

        // sum
        for(int n = 0; n < N; ++n)
        {
            const ADataType v_a = a_m_n(m, n);

            v_exp_sum += ck::math::exp(v_a - v_max);
        }

        // elementwise
        for(int n = 0; n < N; ++n)
        {
            const ADataType v_a = a_m_n(m, n);

            b_m_n(m, n) = ck::math::exp(v_a - v_max) / v_exp_sum;
        }
    };

    make_ParallelTensorFunctor(f, b_m_n.mDesc.GetLengths()[0])(
        std::thread::hardware_concurrency());
}
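The max subtraction in this reference is the standard numerical-stability trick, and it leaves the result unchanged because the common factor $e^{-x_{\max}}$ cancels between numerator and denominator:

$$\mathrm{softmax}(x)_n = \frac{e^{x_n}}{\sum_k e^{x_k}} = \frac{e^{x_n - x_{\max}}}{\sum_k e^{x_k - x_{\max}}}, \qquad x_{\max} = \max_j x_j.$$

Without it, exp overflows float (and half far sooner) once logits reach moderate magnitudes.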
example/91_tile_program/softmax.cpp (+6 -40)

@@ -14,51 +14,14 @@
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"

+#include "reference_softmax.hpp"
 #include "softmax.hpp"

-template <typename ADataType, typename AccDataType, typename BDataType>
-void reference_softmax(const Tensor<ADataType>& a_m_n, Tensor<BDataType>& b_m_n)
-{
-    auto f = [&](auto m) {
-        const int N = a_m_n.mDesc.GetLengths()[1];
-
-        AccDataType v_max = ck::NumericLimits<ADataType>::Lowest();
-
-        // max
-        for(int n = 0; n < N; ++n)
-        {
-            const ADataType v_a = a_m_n(m, n);
-
-            v_max = v_max < v_a ? v_a : v_max;
-        }
-
-        AccDataType v_exp_sum = 0;
-
-        // sum
-        for(int n = 0; n < N; ++n)
-        {
-            const ADataType v_a = a_m_n(m, n);
-
-            v_exp_sum += ck::math::exp(v_a - v_max);
-        }
-
-        // elementwise
-        for(int n = 0; n < N; ++n)
-        {
-            const ADataType v_a = a_m_n(m, n);
-
-            b_m_n(m, n) = ck::math::exp(v_a - v_max) / v_exp_sum;
-        }
-    };
-
-    make_ParallelTensorFunctor(f, b_m_n.mDesc.GetLengths()[0])(
-        std::thread::hardware_concurrency());
-}
-
 int main(int argc, char* argv[])
 {
-    using ADataType   = float;
+    using ADataType   = ck::half_t;
     using AccDataType = float;
-    using BDataType   = float;
+    using BDataType   = ck::half_t;

     ck::index_t M = 3328;
     ck::index_t N = 4096;
@@ -118,5 +81,8 @@ int main(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;

+    LogRangeAsType<float>(std::cout << "dev: ", b_host_dev.mData, ", ") << std::endl;
+    LogRangeAsType<float>(std::cout << "ref: ", b_host_ref.mData, ", ") << std::endl;
+
     return !ck::utils::check_err(b_host_dev, b_host_ref);
 }
include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_default_policy.hpp (+4 -0)

@@ -26,6 +26,7 @@ struct BlockGemmARegBSmemCRegV1DefaultPolicy
     {
         using namespace ck::tile_program::warp;

+#if 0
         constexpr index_t kBlockSize = Problem::kBlockSize;
         constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
@@ -46,6 +47,9 @@ struct BlockGemmARegBSmemCRegV1DefaultPolicy
         {
             return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
         }
+#else
+        return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
+#endif
     }
 };
include/ck/tile_program/block_tile/block_reduce.hpp (+31 -4)

@@ -182,10 +182,10 @@ template <typename AccDataType_,
           index_t... InReduceDims,
           typename ReduceFunc,
           typename InDataType_>
-__host__ __device__ auto block_tile_reduce(const InDistributedTensor_& in_tensor,
-                                           Sequence<InReduceDims...> in_reduce_dims,
-                                           const ReduceFunc& reduce_func,
-                                           const InDataType_& reduce_init)
+__device__ auto block_tile_reduce(const InDistributedTensor_& in_tensor,
+                                  Sequence<InReduceDims...> in_reduce_dims,
+                                  const ReduceFunc& reduce_func,
+                                  const InDataType_& reduce_init)
 {
     using InDataType  = typename InDistributedTensor_::DataType;
     using AccDataType = remove_cvref_t<AccDataType_>;
@@ -222,6 +222,33 @@ __host__ void block_tile_reduce(AccDistributedTensor_&,
 {
 }

+// FIXME: dummy host function for tile program
+template <typename AccDataType_,
+          typename InDistributedTensor_,
+          index_t... InReduceDims,
+          typename ReduceFunc,
+          typename InDataType_>
+__host__ auto block_tile_reduce(const InDistributedTensor_&,
+                                Sequence<InReduceDims...>,
+                                const ReduceFunc&,
+                                const InDataType_&)
+{
+    using InDataType  = typename InDistributedTensor_::DataType;
+    using AccDataType = remove_cvref_t<AccDataType_>;
+
+    static_assert(is_same_v<InDataType, remove_cvref_t<InDataType_>>, "wrong!");
+
+    // declare acc_tensor
+    constexpr auto acc_dstr = make_static_tile_distribution(
+        ck::tile_program::detail::make_reduce_tile_distribution_encoding(
+            InDistributedTensor_::GetTileDistribution().GetStaticTileDistributionEncoding(),
+            Sequence<InReduceDims...>{}));
+
+    auto acc_tensor = make_static_distributed_tensor<AccDataType>(acc_dstr);
+
+    return acc_tensor;
+}
+
 // FIXME: dummy host function for tile program
 template <typename AccDistributedTensor_, typename ReduceFunc>
 __host__ void block_tile_reduce_sync(AccDistributedTensor_&, const ReduceFunc&)
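Design note on this change: block_tile_reduce is now __device__-only on the real path, and the added __host__ overload is, per its own FIXME comment, a dummy. It only constructs the reduced tensor's distribution and returns the default-initialized acc_tensor, performing no reduction, so that host-side compilation of tile programs (for example the decltype expressions in gemm_softmax_gemm.hpp) still type-checks.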
include/ck/tile_program/tile/tile_elementwise.hpp (+1 -1)

@@ -26,7 +26,7 @@ __host__ __device__ void tile_elementwise_inout(const InOutElementFunc& inout_el
         type_pack_element<0, InOutDstrTensors...>::GetThreadBufferSize();

     static_for<0, thread_buffer_size, 1>{}(
-        [&](auto i) { inout_element_func(inout_dstr_tensors.GetThreadBuffer()(i)...); });
+        [&](auto i) { inout_element_func(inout_dstr_tensors.GetThreadBuffer().At(i)...); });
 }

 template <typename InElementFunc, typename... InDstrTensors>
include/ck/tile_program/warp_tile/warp_gemm.hpp (+3 -0)

@@ -22,6 +22,9 @@ using WarpGemmMfmaF16F16F32M16N16K16 =
 using WarpGemmMfmaF16F16F32M32N32K16 =
     WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>;

+using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution = WarpGemmImpl<
+    WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M32N32K8>>;
+
 using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
     WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
         WarpGemmAttributeMfmaImplF16F16F32M32N32K8,
library/include/ck/library/utility/host_tensor.hpp (+13 -4)

@@ -6,6 +6,7 @@
 #include <algorithm>
 #include <cassert>
 #include <iostream>
+#include <iomanip>
 #include <numeric>
 #include <thread>
 #include <utility>
@@ -19,7 +20,11 @@
 #include "ck/library/utility/ranges.hpp"

 template <typename Range>
-std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
+std::ostream& LogRange(std::ostream& os,
+                       Range&& range,
+                       std::string delim,
+                       int precision = std::cout.precision(),
+                       int width     = 0)
 {
     bool first = true;
     for(auto&& v : range)
@@ -28,13 +33,17 @@ std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
             first = false;
         else
             os << delim;
-        os << v;
+        os << std::setw(width) << std::setprecision(precision) << v;
     }
     return os;
 }

 template <typename T, typename Range>
-std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
+std::ostream& LogRangeAsType(std::ostream& os,
+                             Range&& range,
+                             std::string delim,
+                             int precision = std::cout.precision(),
+                             int width     = 0)
 {
     bool first = true;
     for(auto&& v : range)
@@ -43,7 +52,7 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
             first = false;
         else
             os << delim;
-        os << static_cast<T>(v);
+        os << std::setw(width) << std::setprecision(precision) << static_cast<T>(v);
     }
     return os;
 }
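The two new defaulted parameters make the logging helpers format-aware: each value is streamed through std::setw(width) and std::setprecision(precision), and the defaults (width 0, the stream's current precision) reproduce the old output. The commented-out calls in gemm_softmax_gemm.cpp show the intended use: LogRangeAsType<float>(std::cout << "C1 dev: ", c1_host_dev.mData, ", ", 16, 20) would print each element with precision 16 in a 20-character field.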