"3rdparty/googletest-1.13.0/docs/_data/navigation.yml" did not exist on "9484fd1c7db21381124dfd3581fd7f738d5f8e9c"
transpose_removing_padding.cu 7.46 KB
Newer Older
yuguo-Jack's avatar
yuguo-Jack committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// 
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// 
//     http://www.apache.org/licenses/LICENSE-2.0
// 
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

/**
 * Transposes a padded per-head tensor into packed token-major rows and drops
 * the padding tokens.
 *
 *   input : [batch_size, num_head, max_len_this_time, head_dim]
 *   output: [token_num, num_head * head_dim]
 *
 * Each thread handles VecSize contiguous output elements per iteration and
 * advances by gridDim.x * blockDim.x * VecSize, so every output element below
 * `elem_cnt` is visited regardless of the grid size chosen by the launcher.
 *
 * For packed token `t`, the padded position is `t + padding_offset[t]` (or `t`
 * when padding_offset is null); it is split into batch = pos / seq_len and
 * position-in-sequence = pos % seq_len. Tokens whose batch has
 * seq_lens[batch] == 0 are skipped, leaving that part of the output untouched.
 *
 * Preconditions (not checked here):
 *   - head_dim % VecSize == 0, so a vector never straddles two heads and the
 *     vectorized Load/Store addresses stay aligned;
 *   - elem_cnt is the packed output size (token_num * num_head * head_dim).
 *
 * batch_size and token_num are part of the signature but are not read here.
 */
template <typename T, int VecSize>
__global__ void TransposeRemovingPadding(const T* input_data,
                                         const int* seq_lens,
                                         T* output_data,
                                         const int batch_size,
                                         const int num_head,
                                         const int max_len_this_time,
                                         const int seq_len,
                                         const int head_dim,
                                         const int token_num,
                                         const int elem_cnt,
                                         const int* padding_offset) {
  const int dim_embed = num_head * head_dim;
  using LoadT = AlignedVector<T, VecSize>;
  LoadT src_vec;

  // The padded input holds batch_size * num_head * max_len_this_time *
  // head_dim elements, which can exceed INT_MAX even when the packed output
  // (elem_cnt) still fits in int, so input offsets are formed in 64 bits.
  const int64_t head_stride =
      static_cast<int64_t>(max_len_this_time) * head_dim;
  const int64_t batch_stride = static_cast<int64_t>(num_head) * head_stride;
  // 64-bit loop index/stride: `linear_index += step` must not wrap near
  // INT_MAX on the final iteration.
  const int64_t step =
      static_cast<int64_t>(gridDim.x) * blockDim.x * VecSize;

  for (int64_t linear_index =
           (static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) *
           VecSize;
       linear_index < elem_cnt;
       linear_index += step) {
    const int token_idx = static_cast<int>(linear_index / dim_embed);
    const int ori_token_idx =
        token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]);
    const int ori_batch_id = ori_token_idx / seq_len;
    if (seq_lens && seq_lens[ori_batch_id] == 0) continue;
    const int ori_seq_id = ori_token_idx % seq_len;
    const int in_token_offset = static_cast<int>(linear_index % dim_embed);
    const int ori_head_id = in_token_offset / head_dim;
    const int ori_head_lane = in_token_offset % head_dim;
    const int64_t ori_idx = ori_batch_id * batch_stride +
                            ori_head_id * head_stride +
                            static_cast<int64_t>(ori_seq_id) * head_dim +
                            ori_head_lane;
    Load<T, VecSize>(&input_data[ori_idx], &src_vec);
    Store<T, VecSize>(src_vec, &output_data[linear_index]);
  }
}

/**
 * Host launcher for TransposeRemovingPadding, enqueued on `cu_stream`.
 *
 *   [batch_size, num_head, max_len_this_time, head_dim]
 *     -> [token_num, num_head * head_dim]
 *
 * Each kernel access moves 16 bytes (PackSize = 16 / sizeof(T) elements), so
 * head_dim is expected to be a multiple of PackSize. One thread is assigned
 * per packed vector of the output, in blocks of 128 threads.
 *
 * Returns without launching when the output is empty: a grid of zero blocks
 * is an invalid launch configuration.
 *
 * The launch is asynchronous; no error check or synchronization is done here.
 */
template <typename T>
void InvokeTransposeRemovePadding(const T* input_data,
                                  const int* seq_lens,
                                  T* output_data,
                                  const int batch_size,
                                  const int num_head,
                                  const int max_len_this_time,
                                  const int seq_len,
                                  const int head_dim,
                                  const int token_num,
                                  const int* padding_offset,
                                  cudaStream_t cu_stream) {
  constexpr int VEC_16B = 16;
  constexpr int PackSize = VEC_16B / sizeof(T);
  const int elem_cnt = token_num * num_head * head_dim;
  if (elem_cnt == 0) {
    // Nothing to copy (e.g. no tokens this step); avoid a 0-block launch.
    return;
  }
  // Ceil-div without the `+ PackSize - 1` overflow risk near INT_MAX.
  const int32_t pack_num = elem_cnt / PackSize + (elem_cnt % PackSize != 0);
  const int32_t block_size = 128;
  const int32_t grid_size = (pack_num + block_size - 1) / block_size;
  TransposeRemovingPadding<T, PackSize>
      <<<grid_size, block_size, 0, cu_stream>>>(input_data,
                                                seq_lens,
                                                output_data,
                                                batch_size,
                                                num_head,
                                                max_len_this_time,
                                                seq_len,
                                                head_dim,
                                                token_num,
                                                elem_cnt,
                                                padding_offset);
}

/**
 * Body of the `transpose_remove_padding` op for a fixed dtype D.
 *
 *   input          : [bsz, num_head, seq_len, dim_head]
 *   seq_lens       : per-batch values, read on device as int32
 *   padding_offset : per-token values, read on device as int32; its first
 *                    dimension is taken as token_num
 *   returns        : {out} with shape [token_num, num_head * dim_head]. `out`
 *                    is zero-filled first, so rows the kernel skips (batches
 *                    with seq_lens == 0) remain zero.
 *
 * input.shape[2] is passed as both max_len_this_time and seq_len.
 * NOTE(review): this assumes padding_offset was computed against that same
 * padded length — confirm against the producer of padding_offset.
 *
 * Throws (PD_THROW) if `input` is not 4-D, or if dim_head is not a multiple
 * of 16 bytes (required by the kernel's 16-byte vectorized loads/stores).
 */
template <paddle::DataType D>
std::vector<paddle::Tensor> apply_transpose_remove_padding(const paddle::Tensor& input, 
                                                           const paddle::Tensor& seq_lens, 
                                                           const paddle::Tensor& padding_offset) {
    typedef PDTraits<D> traits_;
    typedef typename traits_::DataType DataType_;
    typedef typename traits_::data_t data_t;

    const std::vector<int64_t> input_shape = input.shape();
    // Indexing input_shape[3] below would read past the vector otherwise.
    if (input_shape.size() != 4) {
        PD_THROW(
            "transpose_remove_padding expects a 4-D input "
            "[bsz, num_head, seq_len, head_dim].");
    }
    const int bsz = input_shape[0];
    const int num_head = input_shape[1];
    const int seq_len = input_shape[2];
    const int dim_head = input_shape[3];
    const int token_num = padding_offset.shape()[0];

    // The kernel moves 16-byte vectors; a vector must neither straddle two
    // heads nor start at a misaligned address, so head_dim * sizeof(T) must
    // be a multiple of 16 bytes.
    constexpr int kPackSize = 16 / sizeof(DataType_);
    if (dim_head % kPackSize != 0) {
        PD_THROW(
            "transpose_remove_padding requires head_dim * sizeof(dtype) "
            "to be a multiple of 16 bytes.");
    }

    auto cu_stream = input.stream();
    auto out = paddle::full({token_num, num_head * dim_head}, 0, input.dtype(), input.place());
    InvokeTransposeRemovePadding(
        reinterpret_cast<const DataType_*>(input.data<data_t>()),
        seq_lens.data<int>(),
        reinterpret_cast<DataType_*>(out.data<data_t>()),
        bsz,
        num_head,
        seq_len,  // max_len_this_time
        seq_len,
        dim_head,
        token_num,
        padding_offset.data<int>(),
        cu_stream
    );
    return {out};
}

/**
 * Kernel function registered for `transpose_remove_padding`.
 *
 * Selects the apply_transpose_remove_padding instantiation that matches the
 * runtime dtype of `input`. bfloat16, float16 and float32 are handled; any
 * other dtype is rejected with PD_THROW.
 */
std::vector<paddle::Tensor> ApplyTransposeRemovingPadding(const paddle::Tensor& input, 
                                                          const paddle::Tensor& seq_lens, 
                                                          const paddle::Tensor& padding_offset) {
    const paddle::DataType dtype = input.type();

    if (dtype == paddle::DataType::BFLOAT16) {
        return apply_transpose_remove_padding<paddle::DataType::BFLOAT16>(
            input, seq_lens, padding_offset);
    }
    if (dtype == paddle::DataType::FLOAT16) {
        return apply_transpose_remove_padding<paddle::DataType::FLOAT16>(
            input, seq_lens, padding_offset);
    }
    if (dtype == paddle::DataType::FLOAT32) {
        return apply_transpose_remove_padding<paddle::DataType::FLOAT32>(
            input, seq_lens, padding_offset);
    }

    // Every supported dtype has returned above.
    PD_THROW(
        "NOT supported data type. "
        "Only float16, bfloat16 and float32 are supported. ");
}

/**
 * Shape inference for `transpose_remove_padding`.
 *
 * The single output has shape [token_num, num_head * head_dim]:
 *   token_num comes from dim 0 of padding_offset,
 *   num_head and head_dim come from dims 1 and 3 of input.
 * seq_lens_shape is not consulted.
 */
std::vector<std::vector<int64_t>> ApplyTransposeRemovingPaddingInferShape(
        const std::vector<int64_t>& input_shape, 
        const std::vector<int64_t>& seq_lens_shape,
        const std::vector<int64_t>& padding_offset_shape) {
    const int64_t rows = padding_offset_shape[0];
    const int64_t cols = input_shape[1] * input_shape[3];

    std::vector<std::vector<int64_t>> shapes;
    shapes.push_back({rows, cols});
    return shapes;
}

/**
 * Dtype inference for `transpose_remove_padding`: the single output takes the
 * dtype of `input`. The dtypes of seq_lens and padding_offset are not used.
 */
std::vector<paddle::DataType> ApplyTransposeRemovingPaddingInferDtype(
        const paddle::DataType& input_dtype, 
        const paddle::DataType& seq_lens_dtype,
        const paddle::DataType& padding_offset_dtype) {
    std::vector<paddle::DataType> out_dtypes(1, input_dtype);
    return out_dtypes;
}

// Registers the custom op `transpose_remove_padding`.
//   inputs : input [bsz, num_head, seq_len, head_dim], seq_lens, padding_offset
//   output : fmha_out [token_num, num_head * head_dim], same dtype as input
// Kernel, shape-inference and dtype-inference functions are defined above.
PD_BUILD_OP(transpose_remove_padding)
    .Inputs({"input", "seq_lens", "padding_offset"})
    .Outputs({"fmha_out"})
    .SetKernelFn(PD_KERNEL(ApplyTransposeRemovingPadding))
    .SetInferShapeFn(PD_INFER_SHAPE(ApplyTransposeRemovingPaddingInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(ApplyTransposeRemovingPaddingInferDtype));