/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "extensions.h"

namespace transformer_engine {
namespace jax {

// Fills a CustomCallCommonDescriptor from the arguments and hands it to PackOpaque().
//
//   shape     - passed to the descriptor's shape.from_vector().
//   in_dtype  - stored in the descriptor's in_dtype field.
//   out_dtype - stored in the descriptor's out_dtype field.
//   act_enum  - stored in the descriptor's act_enum field.
//
// Returns whatever PackOpaque() produces for the populated descriptor.
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
                                               DType out_dtype, size_t act_enum) {
  // Value-initialize first so any field not assigned below starts zeroed.
  CustomCallCommonDescriptor packed_args{};

  packed_args.shape.from_vector(shape);

  packed_args.act_enum = act_enum;
  packed_args.in_dtype = in_dtype;
  packed_args.out_dtype = out_dtype;

  return PackOpaque(packed_args);
}

// Fills a CustomCallCommonWkDescriptor from the arguments and hands it to PackOpaque().
//
//   shape     - passed to the descriptor's shape.from_vector().
//   wkshape   - passed to the descriptor's wkshape.from_vector().
//   in_dtype  - stored in the descriptor's in_dtype field.
//   out_dtype - stored in the descriptor's out_dtype field.
//   wk_dtype  - stored in the descriptor's wk_dtype field.
//   act_enum  - stored in the descriptor's act_enum field.
//
// Returns whatever PackOpaque() produces for the populated descriptor.
pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
                                                 const std::vector<size_t> &wkshape, DType in_dtype,
                                                 DType out_dtype, DType wk_dtype, size_t act_enum) {
  // Value-initialize first so any field not assigned below starts zeroed.
  CustomCallCommonWkDescriptor packed_args{};

  // Shapes (same order as the parameters).
  packed_args.shape.from_vector(shape);
  packed_args.wkshape.from_vector(wkshape);

  // Scalar fields.
  packed_args.act_enum = act_enum;
  packed_args.in_dtype = in_dtype;
  packed_args.out_dtype = out_dtype;
  packed_args.wk_dtype = wk_dtype;

  return PackOpaque(packed_args);
}

35
36
37
38
// Fills a CustomCallNormDescriptor from the arguments and hands it to PackOpaque().
//
// Each parameter is stored unchanged in the descriptor field of the same name:
// batch_size, hidden_size, wkspace_size, x_dtype, w_dtype, wkspace_dtype,
// zero_centered_gamma, eps and sm_margin.
//
// Returns whatever PackOpaque() produces for the populated descriptor.
pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
                                             size_t wkspace_size, DType x_dtype, DType w_dtype,
                                             DType wkspace_dtype, bool zero_centered_gamma,
                                             float eps, int sm_margin) {
  // Value-initialize first so any field not assigned below starts zeroed.
  CustomCallNormDescriptor norm_args{};

  // Sizes.
  norm_args.batch_size = batch_size;
  norm_args.hidden_size = hidden_size;
  norm_args.wkspace_size = wkspace_size;

  // Data types.
  norm_args.x_dtype = x_dtype;
  norm_args.w_dtype = w_dtype;
  norm_args.wkspace_dtype = wkspace_dtype;

  // Remaining options.
  norm_args.eps = eps;
  norm_args.sm_margin = sm_margin;
  norm_args.zero_centered_gamma = zero_centered_gamma;

  return PackOpaque(norm_args);
}

// Builds a SoftmaxDescriptor from the arguments and hands it to PackOpaque().
//
// The descriptor is brace-initialized positionally, so the initializer order below
// (batch_size, padding_size, head_dim, q_seqlen, k_seqlen, dtype, scale_factor)
// must stay in step with the member order of SoftmaxDescriptor.
//
// Returns whatever PackOpaque() produces for the descriptor.
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
                                                size_t head_dim, size_t q_seqlen, size_t k_seqlen,
                                                DType dtype, float scale_factor) {
  SoftmaxDescriptor softmax_args{batch_size, padding_size, head_dim,    q_seqlen,
                                 k_seqlen,   dtype,        scale_factor};
  return PackOpaque(softmax_args);
}

// Builds a CustomCallFusedAttnDescriptor from the arguments and hands it to PackOpaque().
//
// The descriptor is brace-initialized positionally with the parameters in exactly
// the order they are declared, so that order must stay in step with the member
// order of CustomCallFusedAttnDescriptor.
//
// Returns whatever PackOpaque() produces for the descriptor.
pybind11::bytes PackCustomCallFusedAttnDescriptor(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
    float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
    NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
    bool deterministic, int64_t window_size_left, int64_t window_size_right) {
  CustomCallFusedAttnDescriptor attn_args{
      // Batch and sequence sizes.
      input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
      // Head configuration.
      attn_heads, num_gqa_groups, bias_heads, head_dim,
      // Segment and workspace sizes.
      max_segments_per_seq, wkspace_size,
      // Floating-point parameters.
      scaling_factor, dropout_probability,
      // Bias, mask and layout selectors.
      bias_type, mask_type, qkv_layout,
      // Data types.
      dtype, wkspace_dtype,
      // Flags.
      is_training, deterministic,
      // Window sizes.
      window_size_left, window_size_right};
  return PackOpaque(attn_args);
}

76
77
}  // namespace jax
}  // namespace transformer_engine