Commit 78640ad0 authored by Dongz, committed by GitHub

[feature]: add Wan Sparge infer (#32)



* [feature]: add Wan Sparge infer

* Update scripts/run_wan_t2v_sparge.sh
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* [minor]: fix typo and use config style

* [minor]: remove breakpoint

* [feature]: add all attn class

* [minor]: remove args

* [minor]: remove shared weights

---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
parent 52166b88
{
    "infer_steps": 50,
    "target_video_length": 81,
    "text_len": 512,
    "target_height": 480,
    "target_width": 832,
    "attention_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 6,
    "sample_shift": 8,
    "enable_cfg": true,
    "cpu_offload": false,
    "sparge": true,
    "sparge_ckpt": "configs/shared_weights/sparge_wan2.1_t2v_1.3B.pt"
}
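
For orientation, a minimal sketch of how a config like this is consumed downstream (the file path below is hypothetical; the sparge/sparge_ckpt check mirrors the assert this commit adds to the attention block):

import json

with open("configs/wan_t2v_sparge.json") as f:  # hypothetical path
    config = json.load(f)

if config["sparge"]:
    # Sparge inference needs the per-block tuned checkpoint.
    assert config["sparge_ckpt"], "sparge_ckpt must be set when sparge is True"
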
@@ -2,3 +2,4 @@ from .mm import *
 from .norm import *
 from .conv import *
 from .tensor import *
+from .attn import *
+from .attn_weight import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import ABCMeta, abstractmethod

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

try:
    from spas_sage_attn.autotune import SparseAttentionMeansim
except ImportError:
    print("SparseAttentionMeansim not found, please install sparge first")
    SparseAttentionMeansim = None

try:
    from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
    print("flash_attn_varlen_func not found, please install flash_attn2 first")
    flash_attn_varlen_func = None

try:
    from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
except ImportError:
    print("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
    flash_attn_varlen_func_v3 = None

# SM 8.9 GPUs take the Triton INT8 QK / FP16 PV kernel; everything else uses
# the default sageattn entry point.
if torch.cuda.get_device_capability(0) == (8, 9):
    try:
        from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
    except ImportError:
        print("sageattn not found, please install sageattention first")
        sageattn = None
else:
    try:
        from sageattention import sageattn
    except ImportError:
        print("sageattn not found, please install sageattention first")
        sageattn = None

class AttnWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name):
        self.weight_name = weight_name
        self.config = {}

    def load(self, weight_dict):
        pass

    @abstractmethod
    def apply(self, input_tensor):
        pass

    def set_config(self, config=None):
        if config is not None:
            self.config = config

    def to_cpu(self, non_blocking=False):
        self.weight = self.weight.to("cpu", non_blocking=non_blocking)

    def to_cuda(self, non_blocking=False):
        self.weight = self.weight.cuda(non_blocking=non_blocking)

@ATTN_WEIGHT_REGISTER("flash_attn2")
class FlashAttn2Weight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
        x = flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
        ).reshape(max_seqlen_q, -1)
        return x


@ATTN_WEIGHT_REGISTER("flash_attn3")
class FlashAttn3Weight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
        # flash_attn3's varlen entry point returns a tuple; take the attention output.
        x = flash_attn_varlen_func_v3(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
        )[0].reshape(max_seqlen_q, -1)
        return x

@ATTN_WEIGHT_REGISTER("sage_attn2")
class SageAttn2Weight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
        if model_cls == "hunyuan":
            # Hunyuan packs two segments into one varlen batch; attend to each
            # separately, then concatenate along the sequence dimension.
            x1 = sageattn(
                q[: cu_seqlens_q[1]].unsqueeze(0),
                k[: cu_seqlens_kv[1]].unsqueeze(0),
                v[: cu_seqlens_kv[1]].unsqueeze(0),
                tensor_layout="NHD",
            )
            x2 = sageattn(
                q[cu_seqlens_q[1] :].unsqueeze(0),
                k[cu_seqlens_kv[1] :].unsqueeze(0),
                v[cu_seqlens_kv[1] :].unsqueeze(0),
                tensor_layout="NHD",
            )
            x = torch.cat((x1, x2), dim=1)
            x = x.view(max_seqlen_q, -1)
        elif model_cls in ["wan2.1", "wan2.1_causvid", "wan2.1_df"]:
            x = sageattn(
                q.unsqueeze(0),
                k.unsqueeze(0),
                v.unsqueeze(0),
                tensor_layout="NHD",
            )
            x = x.view(max_seqlen_q, -1)
        return x

@ATTN_WEIGHT_REGISTER("torch_sdpa")
class TorchSDPAWeight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(self, q, k, v, drop_rate=0, attn_mask=None, causal=False):
        # [B, S, H, D] -> [B, H, S, D], as scaled_dot_product_attention expects.
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
        x = x.transpose(1, 2)
        b, s, a, d = x.shape
        out = x.reshape(b, s, -1)
        return out

@ATTN_WEIGHT_REGISTER("Sparge")
class SpargeAttnWeight(AttnWeightTemplate):
    def __init__(self, weight_name, verbose=False, l1=0.07, pv_l1=0.08, tune_pv=True, inner_attn_type="flash_attn3"):
        self.verbose = verbose
        self.l1 = l1
        self.pv_l1 = pv_l1
        self.tune_pv = tune_pv
        self.inner_attn_type = inner_attn_type
        self.inner_cls = SparseAttentionMeansim(l1=l1, pv_l1=pv_l1, tune_pv=tune_pv)
        super().__init__(weight_name)

    def load(self, weight_dict):
        # Match all keys prefixed with weight_name; set each tuned tensor on the
        # inner module under its last dotted component.
        for key in weight_dict.keys():
            if key.startswith(self.weight_name):
                sub_name = key.split(".")[-1]
                setattr(self.inner_cls, sub_name, nn.Parameter(weight_dict[key], requires_grad=False))

    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
        if len(q.shape) == 3:
            # [S, H, D] -> [1, S, H, D]: SparseAttentionMeansim expects a batch dim.
            q = q.unsqueeze(0)
            k = k.unsqueeze(0)
            v = v.unsqueeze(0)
        x = self.inner_cls(q, k, v, tensor_layout="NHD")
        x = x.flatten(2)
        x = x.squeeze(0)
        return x
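
Taken together, the registry plus these wrapper classes give every call site a uniform apply() interface. A minimal usage sketch, assuming a CUDA device, fp16 tensors of shape [seq_len, num_heads, head_dim], and the checkpoint path from the config above (all concrete values here are illustrative):

import torch
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

# Build the tuned sparse-attention wrapper for block 0; load() copies every
# checkpoint entry prefixed "blocks.0" onto the inner SparseAttentionMeansim.
attn = ATTN_WEIGHT_REGISTER["Sparge"]("blocks.0")
attn.load(torch.load("configs/shared_weights/sparge_wan2.1_t2v_1.3B.pt"))

q = torch.randn(1024, 12, 128, dtype=torch.float16, device="cuda")
k = torch.randn(1024, 12, 128, dtype=torch.float16, device="cuda")
v = torch.randn(1024, 12, 128, dtype=torch.float16, device="cuda")
out = attn.apply(q, k, v)  # apply() adds the batch dim itself; out is [1024, 12 * 128]
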
@@ -105,9 +105,7 @@ class WanTransformerInfer:
         cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k, k_lens=seq_lens)
         if not self.parallel_attention:
-            attn_out = attention(
-                attention_type=self.attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"]
-            )
+            attn_out = weights.self_attn_1.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"])
         else:
             attn_out = self.parallel_attention(
                 attention_type=self.attention_type,
@@ -136,9 +134,7 @@ class WanTransformerInfer:
         cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k, k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device))
-        attn_out = attention(
-            attention_type=self.attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"]
-        )
+        attn_out = weights.cross_attn_1.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"])
         if self.task == "i2v":
             k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d)
@@ -150,7 +146,7 @@ class WanTransformerInfer:
                 k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device),
             )
-            img_attn_out = attention(
-                attention_type=self.attention_type, q=q, k=k_img, v=v_img, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"]
+            img_attn_out = weights.cross_attn_2.apply(
+                q=q, k=k_img, v=v_img, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"]
             )
......
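
All three call sites above feed cu_seqlens_q/cu_seqlens_kv offsets produced by _calculate_q_k_len into the registered backends. That helper is not shown in this diff; as a hedged sketch, a varlen helper of this shape typically computes the cumulative offsets that flash-attn-style kernels expect (the body below is an assumption for illustration, not the repository's implementation):

import torch

def calculate_q_k_len(q, k, k_lens):
    # Packed varlen layout: offsets [0, len_0, len_0 + len_1, ...].
    q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device)
    cu_seqlens_q = torch.cat([q_lens.new_zeros(1), q_lens.cumsum(0, dtype=torch.int32)])
    cu_seqlens_k = torch.cat([k_lens.new_zeros(1), k_lens.cumsum(0, dtype=torch.int32)])
    return cu_seqlens_q, cu_seqlens_k, q.size(0), k.size(0)
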
+import torch
-from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER, TENSOR_REGISTER
+from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER, TENSOR_REGISTER, ATTN_WEIGHT_REGISTER
 from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
@@ -42,9 +43,29 @@ class WanTransformerAttentionBlock(WeightModule):
         self.add_module("ffn_0", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.0.weight", f"blocks.{self.block_index}.ffn.0.bias"))
         self.add_module("ffn_2", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.2.weight", f"blocks.{self.block_index}.ffn.2.bias"))
+        # attention weights
+        if self.config["sparge"]:
+            assert self.config["sparge_ckpt"], "sparge_ckpt must be set when sparge is True"
+            self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER["Sparge"](f"blocks.{self.block_index}"))
+            self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
+        else:
+            self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
+            self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
         if self.task == "i2v":
             self.add_module("cross_attn_k_img", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.k_img.weight", f"blocks.{self.block_index}.cross_attn.k_img.bias"))
             self.add_module("cross_attn_v_img", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.v_img.weight", f"blocks.{self.block_index}.cross_attn.v_img.bias"))
             self.add_module("cross_attn_norm_k_img", RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_k_img.weight"))
+            # attention weights
+            self.add_module("cross_attn_2", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
+        # load attn weights
+        if self.config["sparge"]:
+            assert self.config["sparge_ckpt"], "sparge_ckpt must be set when sparge is True"
+            sparge_ckpt = torch.load(self.config["sparge_ckpt"])
+            self.self_attn_1.load(sparge_ckpt)
+        else:
+            # do not load weights
+            pass
         self.register_parameter("modulation", TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.modulation"))
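
To make the prefix-matched loading concrete, a tiny sketch of what self_attn_1.load(sparge_ckpt) keeps for block 0 (the key names below are hypothetical, not taken from a real Sparge checkpoint):

import torch

sparge_ckpt = {
    "blocks.0.attn.scale": torch.zeros(1),  # hypothetical tuned tensor, kept by block 0
    "blocks.1.attn.scale": torch.zeros(1),  # different prefix, ignored by block 0
}
kept = {k.split(".")[-1]: v for k, v in sparge_ckpt.items() if k.startswith("blocks.0")}
print(list(kept))  # ['scale']
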
@@ -45,6 +45,7 @@ class Register(dict):
 MM_WEIGHT_REGISTER = Register()
+ATTN_WEIGHT_REGISTER = Register()
 RMS_WEIGHT_REGISTER = Register()
 LN_WEIGHT_REGISTER = Register()
 CONV3D_WEIGHT_REGISTER = Register()
......
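
The new ATTN_WEIGHT_REGISTER follows the same dict-based Register pattern as its siblings, used both as a decorator (@ATTN_WEIGHT_REGISTER("flash_attn2")) and by lookup (ATTN_WEIGHT_REGISTER["Sparge"]). A minimal sketch of such a Register, assuming the real class in lightx2v/utils/registry_factory.py behaves equivalently:

class Register(dict):
    def __call__(self, name):
        # @REGISTER("name") stores the decorated class under `name` and
        # returns the class unchanged.
        def decorator(cls):
            self[name] = cls
            return cls
        return decorator
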
@@ -7,7 +7,7 @@ model_path=
 # check section
 if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
     cuda_devices=0
-    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
+    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
     export CUDA_VISIBLE_DEVICES=${cuda_devices}
 fi
......
@@ -7,7 +7,7 @@ model_path=
 # check section
 if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
     cuda_devices=0
-    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
+    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
     export CUDA_VISIBLE_DEVICES=${cuda_devices}
 fi
......
@@ -7,7 +7,7 @@ model_path=
 # check section
 if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
     cuda_devices=0
-    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
+    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
     export CUDA_VISIBLE_DEVICES=${cuda_devices}
 fi
......
@@ -7,7 +7,7 @@ model_path=
 # check section
 if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
     cuda_devices=0,1,2,3
-    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
+    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
     export CUDA_VISIBLE_DEVICES=${cuda_devices}
 fi
......
@@ -7,7 +7,7 @@ model_path=
 # check section
 if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
     cuda_devices=0
-    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
+    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
     export CUDA_VISIBLE_DEVICES=${cuda_devices}
 fi
......
@@ -7,7 +7,7 @@ model_path=
 # check section
 if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
     cuda_devices=0
-    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
+    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
     export CUDA_VISIBLE_DEVICES=${cuda_devices}
 fi
......
@@ -7,7 +7,7 @@ model_path=
 # check section
 if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
     cuda_devices=0
-    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
+    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
     export CUDA_VISIBLE_DEVICES=${cuda_devices}
 fi
......
@@ -7,7 +7,7 @@ model_path=
 # check section
 if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
     cuda_devices=0
-    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
+    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
     export CUDA_VISIBLE_DEVICES=${cuda_devices}
 fi
......
@@ -8,7 +8,7 @@ lora_path=
 # check section
 if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
     cuda_devices=0
-    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
+    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
     export CUDA_VISIBLE_DEVICES=${cuda_devices}
 fi
......
@@ -7,7 +7,7 @@ model_path="/mnt/Text2Video/wangshankun/HF_Cache/hub/models--Skywork--SkyReels-V
 # check section
 if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
     cuda_devices=0
-    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
+    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
     export CUDA_VISIBLE_DEVICES=${cuda_devices}
 fi
......
@@ -6,7 +6,7 @@ model_path=
 # check section
 if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
     cuda_devices=0
-    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
+    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
     export CUDA_VISIBLE_DEVICES=${cuda_devices}
 fi
......
@@ -7,7 +7,7 @@ model_path=
 # check section
 if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
     cuda_devices=0
-    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
+    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
     export CUDA_VISIBLE_DEVICES=${cuda_devices}
 fi
......
@@ -7,7 +7,7 @@ model_path=
 # check section
 if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
     cuda_devices=0
-    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
+    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
     export CUDA_VISIBLE_DEVICES=${cuda_devices}
 fi
......