"vscode:/vscode.git/clone" did not exist on "6a887cf5b6f128cb22d0e08de11307fe6b4ad359"
Commit 98957dd7 authored by luopl
Browse files

init

parents
Pipeline #1625 canceled with stages
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
import json
import torch
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
from transformers.utils.hub import cached_file
def load_config_hf(model_name):
    """Load and parse the JSON configuration file for ``model_name``.

    The file named ``CONFIG_NAME`` is resolved through ``cached_file`` and
    parsed with ``json.load``.

    Args:
        model_name: Model identifier or path, forwarded unchanged to
            ``cached_file``.

    Returns:
        The object produced by ``json.load`` on the resolved file.
    """
    # NOTE(review): with _raise_exceptions_for_missing_entries=False,
    # cached_file presumably returns None for a missing file, which would make
    # open() below raise TypeError -- confirm against the transformers docs.
    resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
    # A context manager closes the handle deterministically instead of
    # leaving it to the garbage collector.
    with open(resolved_archive_file) as config_file:
        return json.load(config_file)
def load_state_dict_hf(model_name, device=None, dtype=None):
    """Load a model's weights file and return it as a state dict.

    The file named ``WEIGHTS_NAME`` is resolved through ``cached_file`` and
    read with ``torch.load``. Every entry is then cast to ``dtype`` (when one
    is given) and moved to ``device``.

    Args:
        model_name: Model identifier or path, forwarded unchanged to
            ``cached_file``.
        device: Target device for the returned entries. ``None`` leaves them
            on the device they were loaded to.
        dtype: Optional dtype to cast every entry to. ``None`` keeps the
            stored dtype.

    Returns:
        A dict with the same keys as the loaded checkpoint, holding the
        converted values.
    """
    # If not fp32, then we don't want to load directly to the GPU
    mapped_device = "cpu" if dtype not in [torch.float32, None] else device
    resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
    # NOTE(review): torch.load unpickles the file; only use with checkpoints
    # from a trusted source.
    # Bind the result instead of returning it here, so the dtype/device
    # conversion below is actually executed.
    state_dict = torch.load(resolved_archive_file, map_location=mapped_device)
    # Convert dtype before moving to GPU to save memory
    if dtype is not None:
        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
    return state_dict
# Unique model identifier
modelCode=943
# Model name
modelName=mamba2_pytorch
# Model description
modelDescription=Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality
# Application scenarios
appScenario=推理,科研,制造,医疗,家居,教育
# Framework type
frameType=pytorch
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment