// packing.cpp
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "jax/csrc/extensions.h"

namespace transformer_engine {
namespace jax {

// Builds a CustomCallCommonDescriptor from the given shape, input/output dtypes and
// `act_enum` value, and returns it as the opaque pybind11::bytes produced by PackOpaque.
//
// shape     - copied into the descriptor via `shape.from_vector`.
// in_dtype  - stored as `desc.in_dtype`.
// out_dtype - stored as `desc.out_dtype`.
// act_enum  - stored as `desc.act_enum` (presumably an activation-kind enum value; confirm
//             against the consumer of this descriptor).
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
                                               DType out_dtype, size_t act_enum) {
  // Value-initialize so that any field not assigned below is not left indeterminate.
  CustomCallCommonDescriptor desc{};
  desc.shape.from_vector(shape);
  desc.in_dtype = in_dtype;
  desc.out_dtype = out_dtype;
  desc.act_enum = act_enum;
  return PackOpaque(desc);
}

// Builds a CustomCallCommonWkDescriptor: the common descriptor fields plus a second shape
// (`wkshape`) and its dtype (`wk_dtype`). Returns it as the opaque pybind11::bytes
// produced by PackOpaque.
//
// shape     - copied into `desc.shape` via `from_vector`.
// wkshape   - copied into `desc.wkshape` via `from_vector` (presumably a workspace shape,
//             judging by the name; confirm with the consumer).
// in_dtype  - stored as `desc.in_dtype`.
// out_dtype - stored as `desc.out_dtype`.
// wk_dtype  - stored as `desc.wk_dtype`.
// act_enum  - stored as `desc.act_enum`.
pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
                                                 const std::vector<size_t> &wkshape, DType in_dtype,
                                                 DType out_dtype, DType wk_dtype, size_t act_enum) {
  // Value-initialize so that any field not assigned below is not left indeterminate.
  CustomCallCommonWkDescriptor desc{};
  desc.shape.from_vector(shape);
  desc.wkshape.from_vector(wkshape);
  desc.in_dtype = in_dtype;
  desc.out_dtype = out_dtype;
  desc.wk_dtype = wk_dtype;
  desc.act_enum = act_enum;
  return PackOpaque(desc);
}

// Builds a CustomCallNormDescriptor from the given sizes, partial-gradient shapes, dtypes
// and scalar options, and returns it as the opaque pybind11::bytes produced by PackOpaque.
// Every argument is copied into the same-named descriptor field. The two shape vectors are
// copied via `from_vector`; all other arguments are assigned directly.
//
// NOTE(review): the meanings of `eps`, `sm_margin` and `zero_centered_gamma` are
// interpreted by the kernel that consumes this descriptor and are not visible here.
pybind11::bytes PackCustomCallNormDescriptor(
    size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size,
    const std::vector<size_t> &dgamma_part_shape, const std::vector<size_t> &dbeta_part_shape,
    DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype,
    DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin) {
  // Value-initialize so that any field not assigned below is not left indeterminate.
  CustomCallNormDescriptor desc{};

  // Sizes.
  desc.batch_size = batch_size;
  desc.hidden_size = hidden_size;
  desc.wkspace_size = wkspace_size;
  desc.barrier_size = barrier_size;

  // Shapes of the partial dgamma / dbeta buffers.
  desc.dgamma_part_shape.from_vector(dgamma_part_shape);
  desc.dbeta_part_shape.from_vector(dbeta_part_shape);

  // Element types.
  desc.x_dtype = x_dtype;
  desc.w_dtype = w_dtype;
  desc.wkspace_dtype = wkspace_dtype;
  desc.barrier_dtype = barrier_dtype;
  desc.dgamma_part_dtype = dgamma_part_dtype;
  desc.dbeta_part_dtype = dbeta_part_dtype;

  // Scalar options.
  desc.zero_centered_gamma = zero_centered_gamma;
  desc.eps = eps;
  desc.sm_margin = sm_margin;
  return PackOpaque(desc);
}

// Builds a SoftmaxDescriptor from the given dimensions, dtype and scale factor, and returns
// it as the opaque pybind11::bytes produced by PackOpaque.
//
// NOTE(review): the descriptor is brace-initialized positionally. If SoftmaxDescriptor is
// an aggregate, the argument order here must match its field declaration order exactly.
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
                                                size_t head_dim, size_t q_seqlen, size_t k_seqlen,
                                                DType dtype, float scale_factor) {
  return PackOpaque(SoftmaxDescriptor{batch_size, padding_size, head_dim, q_seqlen, k_seqlen, dtype,
                                      scale_factor});
}

// Builds a CustomCallFusedAttnDescriptor from the given batch, sequence-length, head,
// workspace, scaling, dropout, bias/mask/layout and dtype parameters, and returns it as the
// opaque pybind11::bytes produced by PackOpaque.
//
// NOTE(review): the descriptor is brace-initialized positionally. If
// CustomCallFusedAttnDescriptor is an aggregate, the argument order below must match its
// field declaration order exactly. Several adjacent fields share a type (size_t, float),
// so a reordering would not be caught by the compiler.
pybind11::bytes PackCustomCallFusedAttnDescriptor(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
    float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
    NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training) {
  return PackOpaque(CustomCallFusedAttnDescriptor{
      input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads,
      head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type,
      mask_type, qkv_layout, dtype, wkspace_dtype, is_training});
}

}  // namespace jax
}  // namespace transformer_engine