"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "ad47749b827b8087c914d489d2d26ac485121c59"
utils.py 749 Bytes
Newer Older
xingjinliang's avatar
xingjinliang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.

import os

import torch


def get_config_path(project_dir: str) -> str:
    """Return the path of the config copy kept inside a Retro project dir.

    Args:
        project_dir: Root directory of the Retro project.

    Returns:
        The path ``config.json`` joined onto ``project_dir`` via
        ``os.path.join``.
    """
    config_filename = "config.json"
    return os.path.join(project_dir, config_filename)


def get_gpt_data_dir(project_dir: str) -> str:
    """Return the project-relative directory holding GPT bin/idx datasets.

    Args:
        project_dir: Root directory of the Retro project.

    Returns:
        The path ``data`` joined onto ``project_dir`` via ``os.path.join``.
    """
    data_subdir = "data"
    return os.path.join(project_dir, data_subdir)


# ** Note ** : Retro's compatibility between cross attention and Flash/Fused
#   Attention is currently a work in progress. We default to returning None for
#   now.
# def get_all_true_mask(size, device):
#     return torch.full(size=size, fill_value=True, dtype=torch.bool, device=device)
def get_all_true_mask(size, device):
    """Return the mask used for Retro cross attention (currently always None).

    Retro's compatibility between cross attention and Flash/Fused attention
    is still a work in progress, so no mask tensor is built. Both arguments
    are accepted only to keep the call signature stable and are ignored.

    Args:
        size: Intended mask shape; unused for now.
        device: Intended mask device; unused for now.

    Returns:
        Always ``None``.
    """
    # Arguments are intentionally discarded until mask support is restored.
    del size, device
    return None