Commit 88b96b2a authored by sxtyzhangzk's avatar sxtyzhangzk Committed by muyangli
Browse files

[major] fix error when device_id != 0

parent 92d283a2
......@@ -774,7 +774,7 @@ FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Devic
}
}
for (int i = 0; i < 38; i++) {
single_transformer_blocks.push_back(std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, use_fp4, dtype, Device::cuda()));
single_transformer_blocks.push_back(std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, use_fp4, dtype, device));
registerChildren(*single_transformer_blocks.back(), format("single_transformer_blocks.{}", i));
if (offload) {
single_transformer_blocks.back()->setLazyLoad(true);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment