/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#include <cuda_runtime_api.h>
#include <cassert>

#include "common/util/cuda_runtime.h"
#include "utils.h"

namespace transformer_engine {
namespace jax {

/*
 * Query the version number of the CUDA runtime library in use.
 *
 * Returns the integer produced by cudaRuntimeGetVersion (per the CUDA Runtime API docs this is
 * encoded as 1000 * major + 10 * minor). A failing query is reported through NVTE_CHECK_CUDA.
 */
int GetCudaRuntimeVersion() {
    int runtime_version{0};
    NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&runtime_version));
    return runtime_version;
}

/*
 * Report the compute capability of device `gpu_id`.
 *
 * Forwards to transformer_engine::cuda::sm_arch. The returned value uses sm_arch's encoding
 * (presumably 10 * major + minor, e.g. 80 for SM 8.0 -- NOTE(review): confirm in sm_arch).
 */
int GetDeviceComputeCapability(int gpu_id) {
    const int sm_version = transformer_engine::cuda::sm_arch(gpu_id);
    return sm_version;
}
/*
 * Store a two-element RNG state {seed[0], offset} into rng_state_dst.
 *
 * Launch layout: only the thread with threadIdx.x == 0 in block 0 performs the stores; every
 * other thread returns immediately. PopulateRngStateAsync launches this as <<<1, 1>>>.
 *
 * rng_state_dst: device pointer with room for at least two int64_t values.
 * seed:          device pointer; only seed[0] is read.
 * offset:        value written verbatim to rng_state_dst[1].
 */
__global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed,
                                          int64_t offset) {
    const bool is_leader = (blockIdx.x == 0) && (threadIdx.x == 0);
    if (!is_leader) {
        return;
    }
    const int64_t seed_value = seed[0];
    rng_state_dst[0] = seed_value;
    rng_state_dst[1] = offset;
}

/*
 * Enqueue, on `stream`, the population of a fused-attention RNG state {seed, offset}.
 *
 * The offset comes from FusedAttnOffsetManager::Instance().GetAndUpdateOffset(increment).
 * `increment` depends on the backend:
 *   - NVTE_F16_arbitrary_seqlen: a fixed 16.
 *   - any other backend:         ceil(q_max_seqlen * kv_max_seqlen / 128).
 * NOTE(review): GetAndUpdateOffset presumably returns the current offset and advances it by
 * `increment`, so that successive calls get disjoint offset ranges -- confirm in its definition.
 *
 * rng_state_dst: device buffer that receives two int64_t values (seed[0], offset).
 * seed:          device buffer whose first int64_t is the seed.
 * q_max_seqlen, kv_max_seqlen: used only to size `increment` for non-arbitrary-seqlen backends.
 * backend:       fused-attention backend that selects the increment rule.
 * stream:        CUDA stream on which the single-thread kernel is launched.
 *
 * The call is asynchronous. Launch errors are caught by NVTE_CHECK_CUDA(cudaGetLastError()),
 * but faults during kernel execution only surface at a later synchronizing call on `stream`.
 */
void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
                           size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
                           cudaStream_t stream) {
    // Fixed offset advance used for the arbitrary-seqlen backend.
    constexpr size_t kArbitrarySeqlenIncrement = 16;
    // Divisor that turns the q x kv element count into an offset advance for other backends.
    constexpr size_t kThreadsPerCta = 128;

    size_t increment = 0;
    if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
        increment = kArbitrarySeqlenIncrement;
    } else {
        // Ceiling division: ceil(q_max_seqlen * kv_max_seqlen / kThreadsPerCta).
        increment = (q_max_seqlen * kv_max_seqlen + kThreadsPerCta - 1) / kThreadsPerCta;
    }

    auto offset = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(increment);
    // A single thread is enough: the kernel writes only two scalars.
    populate_rng_state_kernel<<<1, 1, 0, stream>>>(reinterpret_cast<int64_t *>(rng_state_dst),
                                                   reinterpret_cast<const int64_t *>(seed), offset);
    NVTE_CHECK_CUDA(cudaGetLastError());
}

}  // namespace jax
}  // namespace transformer_engine