Unverified Commit 99f4156f authored by q.yao's avatar q.yao Committed by GitHub
Browse files

Fix cuda reinitialization in a multiprocessing setting (#862)

parent 20d8f47a
......@@ -83,12 +83,18 @@ class TurbomindModelConfig:
return True
_WEIGHT_DTYPE_MAP = dict(
int4=torch.float16,
fp16=torch.float16,
fp32=torch.float16,
bf16=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
)
def _weight_dtype_map(weight_type: str, default=None):
"""get weight dtype map."""
_WEIGHT_DTYPE_MAP = dict(
int4=torch.float16,
fp16=torch.float16,
fp32=torch.float16,
bf16=torch.bfloat16
if torch.cuda.is_bf16_supported() else torch.float16,
)
return _WEIGHT_DTYPE_MAP.get(weight_type, default)
class BaseOutputModel(ABC):
......@@ -153,8 +159,8 @@ class BaseOutputModel(ABC):
if self.to_file:
if torch.is_floating_point(param):
torch_type = _WEIGHT_DTYPE_MAP.get(self.cfg.weight_type,
torch.float16)
torch_type = _weight_dtype_map(self.cfg.weight_type,
torch.float16)
param = param.to(torch_type)
tprint(name, param.shape)
_tofile(param, osp.join(self.out_dir, name))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment