utils.cu 2.65 KB
Newer Older
1
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#include <cuda_runtime_api.h>
7

8
9
#include <cassert>

10
#include "common/util/cuda_runtime.h"
11
12
13
14
15
16
#include "utils.h"

namespace transformer_engine {
namespace jax {

// Returns the version number reported by cudaRuntimeGetVersion.
// A failing query is handled by NVTE_CHECK_CUDA.
int GetCudaRuntimeVersion() {
  int runtime_version{};
  NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&runtime_version));
  return runtime_version;
}

22
// Forwards to transformer_engine::cuda::sm_arch for the device index `gpu_id`
// and returns its result as an int.
int GetDeviceComputeCapability(int gpu_id) {
  const int arch = transformer_engine::cuda::sm_arch(gpu_id);
  return arch;
}
23

24
25
// Writes a two-element RNG state: rng_state_dst[0] = seed[0], rng_state_dst[1] = offset.
//
// Both pointers are dereferenced in device code. Threads with a positive
// flattened 1-D index do nothing, so the stores are performed once no matter
// how many threads are launched.
__global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed,
                                          int64_t offset) {
  const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (global_idx <= 0) {
    rng_state_dst[0] = seed[0];
    rng_state_dst[1] = offset;
  }
}

// Enqueues, on `stream`, a single-thread kernel that stores seed[0] and a
// freshly obtained offset into rng_state_dst[0] and rng_state_dst[1].
//
// The offset comes from FusedAttnOffsetManager::GetAndUpdateOffset, called with
// an increment of
//   - 16 for the NVTE_F16_arbitrary_seqlen backend, or
//   - ceil(q_max_seqlen * kv_max_seqlen / 128) for every other backend.
//
// `rng_state_dst` and `seed` are reinterpreted as int64_t pointers and are
// dereferenced by the kernel. The function does not synchronize the stream;
// only launch errors are checked (via cudaGetLastError).
void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
                           size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
                           cudaStream_t stream) {
  // Start from the arbitrary-seqlen value and override it for other backends.
  size_t increment = 16;
  if (backend != NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
    constexpr int kThreadsPerCta = 128;
    const size_t num_elements = q_max_seqlen * kv_max_seqlen;
    increment = (num_elements + kThreadsPerCta - 1) / kThreadsPerCta;
  }

  const auto offset = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(increment);

  auto *dst = reinterpret_cast<int64_t *>(rng_state_dst);
  const auto *seed_ptr = reinterpret_cast<const int64_t *>(seed);
  populate_rng_state_kernel<<<1, 1, 0, stream>>>(dst, seed_ptr, offset);
  NVTE_CHECK_CUDA(cudaGetLastError());
}

48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
// Counts the strictly positive entries of cu_seqlen[0, len) and adds that count
// to *out, using one atomicAdd per matching element.
//
// Launch layout: 1-D grid of 1-D blocks with at least `len` threads in total;
// each thread inspects at most one element and surplus threads exit early.
// `out` is only incremented here, so it must hold the desired starting value
// (e.g. zero) before the launch.
__global__ void get_runtime_num_segments_kernel(int32_t *cu_seqlen, size_t len, uint32_t *out) {
  // Compute the index in size_t so it has the same type as `len` and cannot
  // wrap in 32-bit arithmetic when `len` is large.
  const size_t tid = static_cast<size_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  if (tid >= len) return;

  if (cu_seqlen[tid] > 0) {
    // The counter is a 32-bit unsigned value, matching the
    // atomicAdd(unsigned int*, unsigned int) overload.
    atomicAdd(out, 1u);
  }
}

// Returns how many of the `len` int32 entries in the device array `cu_seqlen`
// are strictly positive.
//
// Parameters:
//   cu_seqlen - device pointer, read as int32_t[len] by the counting kernel.
//   workspace - device scratch buffer of at least sizeof(uint32_t) bytes; it is
//               zeroed and then used as the atomic counter.
//   len       - number of entries to inspect; 0 yields a result of 0.
//   stream    - stream on which all work is enqueued.
//
// This call blocks the host: it synchronizes `stream` before returning so the
// count can be read back. Every CUDA call is checked with NVTE_CHECK_CUDA.
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream) {
  uint32_t *dout = static_cast<uint32_t *>(workspace);
  uint32_t hout{};
  NVTE_CHECK_CUDA(cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream));

  // With len == 0 the ceil-division below would underflow in size_t and yield
  // an invalid grid size, so skip the launch; the zeroed counter is the answer.
  if (len > 0) {
    constexpr int threads = 128;
    const auto blocks = static_cast<unsigned int>((len - 1) / threads + 1);
    get_runtime_num_segments_kernel<<<blocks, threads, 0, stream>>>(
        static_cast<int32_t *>(cu_seqlen), len, dout);
    // Kernel launches do not return a status; pick up launch-configuration errors here.
    NVTE_CHECK_CUDA(cudaGetLastError());
  }

  NVTE_CHECK_CUDA(
      cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream));
  // The copy is asynchronous; wait for it before reading hout on the host.
  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
  return hout;
}

72
73
}  // namespace jax
}  // namespace transformer_engine