Unverified Commit 37a27712 authored by Muyang Li, committed by GitHub

Merge pull request #340 from mit-han-lab/dev

feat: support PuLID, Double FBCache and TeaCache; better linter
parents c1d6fc84 760ab022
@@ -20,36 +20,59 @@ public:
        ModuleWrapper::init(deviceId);
        CUDADeviceContext ctx(this->deviceId);
        net = std::make_unique<FluxModel>(
            use_fp4, offload, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
    }

    bool isBF16() {
        checkModel();
        return net->dtype == Tensor::BF16;
    }
pybind11::function residual_callback;
void set_residual_callback(pybind11::function callback) {
pybind11::gil_scoped_acquire gil;
if (!callback || callback.is_none()) {
residual_callback = pybind11::function();
if (net) {
net->set_residual_callback(nullptr);
}
return;
}
residual_callback = std::move(callback);
if (net) {
pybind11::object cb = residual_callback;
net->set_residual_callback([cb](const Tensor &x) -> Tensor {
pybind11::gil_scoped_acquire gil;
torch::Tensor torch_x = to_torch(x);
pybind11::object result = cb(torch_x);
torch::Tensor torch_y = result.cast<torch::Tensor>();
Tensor y = from_torch(torch_y);
return y;
});
        }
    }
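The binding above acquires the GIL, converts the C++ residual tensor to a torch.Tensor, calls the registered Python callable, and converts the result back. A minimal registration sketch from the Python side, assuming the compiled extension is importable as nunchaku._C (the module and class names are taken from the bindings in this diff, not from documented API):

```python
# Hypothetical sketch: registering a residual callback (e.g. for a PuLID branch).
import torch

def residual_callback(residual: torch.Tensor) -> torch.Tensor:
    # Receives the transformer residual as a torch.Tensor and must return a
    # tensor of the same shape/dtype; a no-op scaling stands in for the real branch.
    return residual * 1.0

# model = nunchaku._C.QuantizedFluxModel()
# model.init(use_fp4=False, offload=False, bf16=True, deviceId=0)
# model.set_residual_callback(residual_callback)  # register
# model.set_residual_callback(None)               # unregister
```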
    torch::Tensor forward(torch::Tensor hidden_states,
                          torch::Tensor encoder_hidden_states,
                          torch::Tensor temb,
                          torch::Tensor rotary_emb_img,
                          torch::Tensor rotary_emb_context,
                          torch::Tensor rotary_emb_single,
                          std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
                          std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt,
                          bool skip_first_layer = false) {
        checkModel();
        CUDADeviceContext ctx(deviceId);

        spdlog::debug("QuantizedFluxModel forward");

        hidden_states = hidden_states.contiguous();
        encoder_hidden_states = encoder_hidden_states.contiguous();
        temb = temb.contiguous();
        rotary_emb_img = rotary_emb_img.contiguous();
        rotary_emb_context = rotary_emb_context.contiguous();
        rotary_emb_single = rotary_emb_single.contiguous();

        Tensor result = net->forward(
            from_torch(hidden_states),
@@ -59,9 +82,10 @@ public:
            from_torch(rotary_emb_context),
            from_torch(rotary_emb_single),
            controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
            controlnet_single_block_samples.has_value()
                ? from_torch(controlnet_single_block_samples.value().contiguous())
                : Tensor{},
            skip_first_layer);

        torch::Tensor output = to_torch(result);
        Tensor::synchronizeDevice();
@@ -69,25 +93,24 @@ public:
        return output;
    }
    std::tuple<torch::Tensor, torch::Tensor>
    forward_layer(int64_t idx,
                  torch::Tensor hidden_states,
                  torch::Tensor encoder_hidden_states,
                  torch::Tensor temb,
                  torch::Tensor rotary_emb_img,
                  torch::Tensor rotary_emb_context,
                  std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
                  std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt) {
        CUDADeviceContext ctx(deviceId);

        spdlog::debug("QuantizedFluxModel forward_layer {}", idx);

        hidden_states = hidden_states.contiguous();
        encoder_hidden_states = encoder_hidden_states.contiguous();
        temb = temb.contiguous();
        rotary_emb_img = rotary_emb_img.contiguous();
        rotary_emb_context = rotary_emb_context.contiguous();

        auto &&[hidden_states_, encoder_hidden_states_] = net->forward_layer(
            idx,
@@ -97,35 +120,31 @@ public:
            from_torch(rotary_emb_img),
            from_torch(rotary_emb_context),
            controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
            controlnet_single_block_samples.has_value()
                ? from_torch(controlnet_single_block_samples.value().contiguous())
                : Tensor{});

        hidden_states = to_torch(hidden_states_);
        encoder_hidden_states = to_torch(encoder_hidden_states_);
        Tensor::synchronizeDevice();

        return {hidden_states, encoder_hidden_states};
    }

    torch::Tensor forward_single_layer(int64_t idx,
                                       torch::Tensor hidden_states,
                                       torch::Tensor temb,
                                       torch::Tensor rotary_emb_single) {
        CUDADeviceContext ctx(deviceId);

        spdlog::debug("QuantizedFluxModel forward_single_layer {}", idx);

        hidden_states = hidden_states.contiguous();
        temb = temb.contiguous();
        rotary_emb_single = rotary_emb_single.contiguous();

        Tensor result = net->single_transformer_blocks.at(idx)->forward(
            from_torch(hidden_states), from_torch(temb), from_torch(rotary_emb_single));

        hidden_states = to_torch(result);
        Tensor::synchronizeDevice();
@@ -133,6 +152,18 @@ public:
        return hidden_states;
    }
// expose the norm1 forward method of the transformer blocks
// this is used by TeaCache to get the norm1 output
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
norm_one_forward(int64_t idx, torch::Tensor hidden_states, torch::Tensor temb) {
AdaLayerNormZero::Output result =
net->transformer_blocks.at(idx)->norm1.forward(from_torch(hidden_states), from_torch(temb));
return {to_torch(result.x),
to_torch(result.gate_msa),
to_torch(result.shift_mlp),
to_torch(result.scale_mlp),
to_torch(result.gate_mlp)};
}
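norm_one_forward exposes the AdaLayerNormZero output of a transformer block so a TeaCache-style policy can inspect the modulated input without running the whole block. A hedged sketch of how a caller might use it to decide whether to reuse cached residuals; the threshold and caching logic are illustrative, not part of this diff:

```python
# Hypothetical TeaCache-style skip decision built on norm_one_forward.
import torch

def should_reuse_cache(model, hidden_states, temb, prev_modulated, threshold=0.05):
    # norm_one_forward returns (x, gate_msa, shift_mlp, scale_mlp, gate_mlp);
    # only the modulated input of block 0 is needed for the distance check.
    modulated, *_ = model.norm_one_forward(0, hidden_states, temb)
    if prev_modulated is None:
        return False, modulated
    rel_change = (modulated - prev_modulated).abs().mean() / prev_modulated.abs().mean()
    return rel_change.item() < threshold, modulated
```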
    // must be called after loading lora
    // skip specific ranks in W4A4 layers
@@ -174,5 +205,4 @@ public:
            throw std::invalid_argument(spdlog::fmt_lib::format("Invalid attention implementation {}", name));
        }
    }
};
@@ -16,7 +16,12 @@ public:
        checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
        spdlog::debug("Stack={}", val);

        net = std::make_unique<GEMM_W4A4>((int)in_features,
                                          (int)out_features,
                                          bias,
                                          use_fp4,
                                          bf16 ? Tensor::BF16 : Tensor::FP16,
                                          Device::cuda((int)deviceId));
    }

    torch::Tensor forward(torch::Tensor x) {
@@ -53,11 +58,11 @@ public:
        // activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t (uint4)
        constexpr int BLOCK_M = 256;
        constexpr int WARP_K = 64;
        constexpr int NUM_WARPS = 8;
        constexpr int WARP_M_TILES = 2;
        constexpr int WARP_SIZE = 32;

        std::stringstream ss;

        for (int bm = 0; bm < M / BLOCK_M; bm++) {
@@ -95,13 +100,10 @@ public:
        x = x.contiguous();

        auto qout = net->quantize(from_torch(x), fuse_glu);

        Tensor act = qout.act.copy(Device::cpu());
        Tensor ascales = qout.ascales.copy(Device::cpu());
        Tensor lora_act = qout.lora_act.copy(Device::cpu());
        Tensor::synchronizeDevice();
@@ -109,5 +111,4 @@ public:
        spdlog::debug("act = {}", dumpTensorINT4(act));
        spdlog::debug("ascales = {}", dumpTensorBF16(ascales));
    }
};
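The layout comment above packs INT4 activations as [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] elements of packed_act_t (uint4, 16 bytes, i.e. 32 four-bit values each). A small sanity-check sketch that this tiling covers exactly M x K four-bit values; the example sizes are arbitrary:

```python
# Sanity check for the packed-activation layout described in QuantizedGEMM above.
BLOCK_M, WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE = 256, 64, 8, 2, 32
INT4_PER_PACKED = 32  # one uint4 element holds 32 int4 values (16 bytes)

def packed_act_elements(M: int, K: int) -> int:
    assert M % BLOCK_M == 0 and K % WARP_K == 0
    return (M // BLOCK_M) * (K // WARP_K) * NUM_WARPS * WARP_M_TILES * WARP_SIZE

M, K = 1024, 3072
assert packed_act_elements(M, K) * INT4_PER_PACKED == M * K   # every int4 value accounted for
assert packed_act_elements(M, K) * 16 == M * K // 2           # matches the "[M, K / 2]" byte view
```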
@@ -10,13 +10,14 @@ class QuantizedGEMM88 : public ModuleWrapper<GEMM_W8A8> {
public:
    void init(int64_t in_features, int64_t out_features, bool bias, bool bf16, int8_t deviceId) {
        spdlog::info("Initializing QuantizedGEMM88");

        size_t val = 0;
        checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192));
        checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
        spdlog::debug("Stack={}", val);

        net = std::make_unique<GEMM_W8A8>(
            (int)in_features, (int)out_features, bias, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
    }

    torch::Tensor forward(torch::Tensor x) {
@@ -27,10 +28,10 @@ public:
        x = x.contiguous();

        Tensor result = net->forward(from_torch(x));

        torch::Tensor output = to_torch(result);
        Tensor::synchronizeDevice();

        return output;
    }
};
@@ -18,7 +18,7 @@ public:
        debugContext.reset();
        net.reset();
        Tensor::synchronizeDevice();
        nunchaku::utils::trim_memory();
        Tensor::synchronizeDevice();
    }
@@ -28,7 +28,7 @@ public:
        CUDADeviceContext ctx(this->deviceId);
        spdlog::info("{} weights from {}", partial ? "Loading partial" : "Loading", path);
        std::shared_ptr<SafeTensors> provider = std::make_shared<SafeTensors>(path);
        net->loadParams(*provider, partial);
        Tensor::synchronizeDevice();
@@ -41,7 +41,7 @@ public:
        CUDADeviceContext ctx(this->deviceId);
        spdlog::info("{} weights from pytorch", partial ? "Loading partial" : "Loading");
        std::shared_ptr<TensorsProviderTorch> provider = std::make_shared<TensorsProviderTorch>(std::move(dict));
        net->loadParams(*provider, partial);
        Tensor::synchronizeDevice();
@@ -66,7 +66,7 @@ public:
                result[key] = to_torch(value);
            }
        }
        return result;
    }
@@ -82,4 +82,4 @@ protected:
    std::unique_ptr<DebugContext> debugContext;
    int deviceId = -1;
};
@@ -7,175 +7,132 @@
namespace nunchaku::ops {

void gemm_w4a4(std::optional<torch::Tensor> act,            // packed act [M, K / 2]
               std::optional<torch::Tensor> wgt,            // packed act [N, K / 2]
               std::optional<torch::Tensor> out,            // linear [M, N]
               std::optional<torch::Tensor> qout,           // packed act [M, N / 2]
               std::optional<torch::Tensor> ascales,        // packed as [K / 64, M]
               std::optional<torch::Tensor> wscales,        // packed ws [K / 64, N]
               std::optional<torch::Tensor> oscales,        // packed as [N / 64, M]
               std::optional<torch::Tensor> poolout,        // linear [M / PoolSize, N]
               std::optional<torch::Tensor> lora_act_in,    // packed lora_act [M, R]
               std::optional<torch::Tensor> lora_up,        // packed lora_wgt [N, R]
               std::optional<torch::Tensor> lora_down,      // packed lora_wgt [N, R]
               std::optional<torch::Tensor> lora_act_out,   // packed lora_act [M, R]
               std::optional<torch::Tensor> norm_q,         // linear [HEAD_DIM]
               std::optional<torch::Tensor> norm_k,         // linear [HEAD_DIM]
               std::optional<torch::Tensor> rotary_emb,     // linear [M, HEAD_DIM / 2, 2, 2]
               std::optional<torch::Tensor> bias,           // packed ws [N]
               std::optional<torch::Tensor> smooth_factor,  // packed ws [N], for quantization of the next layer
               std::optional<torch::Tensor> out_vk,         // linear [B, num_heads, head_dim + 1, head_dim]
               std::optional<torch::Tensor> out_linearattn, // linear [B, (M), N / 3]
               bool act_unsigned,
               std::vector<float> lora_scales,
               bool fuse_silu,
               bool fp4,
               float alpha,
               std::optional<torch::Tensor> wcscales,
               std::optional<torch::Tensor> out_q, // packed attention [B, H, M, D]
               std::optional<torch::Tensor> out_k, // packed attention [B, H, M, D]
               std::optional<torch::Tensor> out_v, // packed attention [B, H, M, D]
               int attn_tokens) {
    spdlog::trace("running gemm_w4a4: ");

    auto getTensor = [](std::optional<torch::Tensor> &t) {
        Tensor ret = t.has_value() ? from_torch(t.value()) : Tensor{};
        if (ret.valid()) {
            spdlog::trace(" {}", ret.shape.str());
        } else {
            spdlog::trace(" <invalid>");
        }
        return ret;
    };
    nunchaku::kernels::gemm_w4a4(getTensor(act),
                                 getTensor(wgt),
                                 getTensor(out),
                                 getTensor(qout),
                                 getTensor(ascales),
                                 getTensor(wscales),
                                 getTensor(oscales),
                                 getTensor(poolout),
                                 getTensor(lora_act_in),
                                 getTensor(lora_up),
                                 getTensor(lora_down),
                                 getTensor(lora_act_out),
                                 getTensor(norm_q),
                                 getTensor(norm_k),
                                 getTensor(rotary_emb),
                                 getTensor(bias),
                                 getTensor(smooth_factor),
                                 getTensor(out_vk),
                                 getTensor(out_linearattn),
                                 act_unsigned,
                                 lora_scales,
                                 fuse_silu,
                                 fp4,
                                 alpha,
                                 getTensor(wcscales),
                                 getTensor(out_q),
                                 getTensor(out_k),
                                 getTensor(out_v),
                                 attn_tokens);
    // Tensor::synchronizeDevice();
}
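The shape comments on the gemm_w4a4 binding can be restated as a small Python reference, a sketch for documentation purposes only (the 64-element quantization group and the head_dim default are assumptions read off the annotations above, not an API):

```python
# Expected tensor shapes for the gemm_w4a4 binding, per the comments above.
def gemm_w4a4_shapes(M: int, N: int, K: int, R: int, head_dim: int = 128):
    return {
        "act": (M, K // 2),        # packed int4 activations (byte view)
        "wgt": (N, K // 2),        # packed int4 weights (same nibble packing)
        "out": (M, N),             # fp16/bf16 output
        "qout": (M, N // 2),       # re-quantized output for the next layer
        "ascales": (K // 64, M),   # per-group activation scales
        "wscales": (K // 64, N),   # per-group weight scales
        "oscales": (N // 64, M),   # scales of the re-quantized output
        "lora_act_in": (M, R), "lora_up": (N, R),
        "lora_down": (N, R), "lora_act_out": (M, R),
        "norm_q": (head_dim,), "norm_k": (head_dim,),
        "rotary_emb": (M, head_dim // 2, 2, 2),
        "bias": (N,), "smooth_factor": (N,),
    }
```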
void attention_fp16(torch::Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
                    torch::Tensor k, // packed [Batch, Head, TokensKV, HEAD_DIM]
                    torch::Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM]
                    torch::Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM]
                    float scale) {
    nunchaku::kernels::attention_fp16(from_torch(q), from_torch(k), from_torch(v), from_torch(o), scale);
}

torch::Tensor gemv_awq(torch::Tensor _in_feats,
                       torch::Tensor _kernel,
                       torch::Tensor _scaling_factors,
                       torch::Tensor _zeros,
                       int64_t m,
                       int64_t n,
                       int64_t k,
                       int64_t group_size) {
    Tensor result = ::gemv_awq(from_torch(_in_feats.contiguous()),
                               from_torch(_kernel.contiguous()),
                               from_torch(_scaling_factors.contiguous()),
                               from_torch(_zeros.contiguous()),
                               (int)m,
                               (int)n,
                               (int)k,
                               (int)group_size);

    torch::Tensor output = to_torch(result);
    // Tensor::synchronizeDevice();
    return output;
}

torch::Tensor
gemm_awq(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros) {
    Tensor result = ::awq_gemm_forward_cuda(from_torch(_in_feats.contiguous()),
                                            from_torch(_kernel.contiguous()),
                                            from_torch(_scaling_factors.contiguous()),
                                            from_torch(_zeros.contiguous()));

    // TODO: allocate output in torch and use from_torch instead (to_torch needs an extra copy)
    torch::Tensor output = to_torch(result);
    // Tensor::synchronizeDevice();
    return output;
}

void test_rmsnorm_rope(
    torch::Tensor input, torch::Tensor output, torch::Tensor norm_q, torch::Tensor norm_k, torch::Tensor rotary_emb) {
    nunchaku::kernels::test_rmsnorm_rope(
        from_torch(input), from_torch(output), from_torch(norm_q), from_torch(norm_k), from_torch(rotary_emb));
}

void test_pack_qkv(torch::Tensor input, torch::Tensor out_q, torch::Tensor out_k, torch::Tensor out_v, int numTokens) {
    nunchaku::kernels::test_pack_qkv(
        from_torch(input), from_torch(out_q), from_torch(out_k), from_torch(out_v), numTokens);
}

}; // namespace nunchaku::ops
@@ -5,80 +5,75 @@
#include "ops.h"
#include "utils.h"
#include <torch/extension.h>
#include "interop/torch.h"
#include <pybind11/pybind11.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    py::class_<QuantizedFluxModel>(m, "QuantizedFluxModel")
        .def(py::init<>())
        .def("init",
             &QuantizedFluxModel::init,
             py::arg("use_fp4"),
             py::arg("offload"),
             py::arg("bf16"),
             py::arg("deviceId"))
.def("set_residual_callback",
[](QuantizedFluxModel &self, pybind11::object call_back) {
if (call_back.is_none()) {
self.set_residual_callback(pybind11::function());
} else {
self.set_residual_callback(call_back);
}
})
.def("reset", &QuantizedFluxModel::reset) .def("reset", &QuantizedFluxModel::reset)
.def("load", &QuantizedFluxModel::load, .def("load", &QuantizedFluxModel::load, py::arg("path"), py::arg("partial") = false)
py::arg("path"), .def("loadDict", &QuantizedFluxModel::loadDict, py::arg("dict"), py::arg("partial") = false)
py::arg("partial") = false .def("forward",
) &QuantizedFluxModel::forward,
.def("loadDict", &QuantizedFluxModel::loadDict, py::arg("hidden_states"),
py::arg("dict"), py::arg("encoder_hidden_states"),
py::arg("partial") = false py::arg("temb"),
) py::arg("rotary_emb_img"),
.def("forward", &QuantizedFluxModel::forward, py::arg("rotary_emb_context"),
py::arg("hidden_states"), py::arg("rotary_emb_single"),
py::arg("encoder_hidden_states"), py::arg("controlnet_block_samples") = py::none(),
py::arg("temb"), py::arg("controlnet_single_block_samples") = py::none(),
py::arg("rotary_emb_img"), py::arg("skip_first_layer") = false)
py::arg("rotary_emb_context"), .def("forward_layer",
py::arg("rotary_emb_single"), &QuantizedFluxModel::forward_layer,
py::arg("controlnet_block_samples") = py::none(), py::arg("idx"),
py::arg("controlnet_single_block_samples") = py::none(), py::arg("hidden_states"),
py::arg("skip_first_layer") = false py::arg("encoder_hidden_states"),
) py::arg("temb"),
.def("forward_layer", &QuantizedFluxModel::forward_layer, py::arg("rotary_emb_img"),
py::arg("idx"), py::arg("rotary_emb_context"),
py::arg("hidden_states"), py::arg("controlnet_block_samples") = py::none(),
py::arg("encoder_hidden_states"), py::arg("controlnet_single_block_samples") = py::none())
py::arg("temb"),
py::arg("rotary_emb_img"),
py::arg("rotary_emb_context"),
py::arg("controlnet_block_samples") = py::none(),
py::arg("controlnet_single_block_samples") = py::none()
)
.def("forward_single_layer", &QuantizedFluxModel::forward_single_layer) .def("forward_single_layer", &QuantizedFluxModel::forward_single_layer)
.def("norm_one_forward", &QuantizedFluxModel::norm_one_forward)
.def("startDebug", &QuantizedFluxModel::startDebug) .def("startDebug", &QuantizedFluxModel::startDebug)
.def("stopDebug", &QuantizedFluxModel::stopDebug) .def("stopDebug", &QuantizedFluxModel::stopDebug)
.def("getDebugResults", &QuantizedFluxModel::getDebugResults) .def("getDebugResults", &QuantizedFluxModel::getDebugResults)
.def("setLoraScale", &QuantizedFluxModel::setLoraScale) .def("setLoraScale", &QuantizedFluxModel::setLoraScale)
.def("setAttentionImpl", &QuantizedFluxModel::setAttentionImpl) .def("setAttentionImpl", &QuantizedFluxModel::setAttentionImpl)
.def("isBF16", &QuantizedFluxModel::isBF16) .def("isBF16", &QuantizedFluxModel::isBF16);
;
py::class_<QuantizedSanaModel>(m, "QuantizedSanaModel") py::class_<QuantizedSanaModel>(m, "QuantizedSanaModel")
.def(py::init<>()) .def(py::init<>())
.def("init", &QuantizedSanaModel::init, .def("init",
py::arg("config"), &QuantizedSanaModel::init,
py::arg("pag_layers"), py::arg("config"),
py::arg("use_fp4"), py::arg("pag_layers"),
py::arg("bf16"), py::arg("use_fp4"),
py::arg("deviceId") py::arg("bf16"),
) py::arg("deviceId"))
.def("reset", &QuantizedSanaModel::reset) .def("reset", &QuantizedSanaModel::reset)
.def("load", &QuantizedSanaModel::load, .def("load", &QuantizedSanaModel::load, py::arg("path"), py::arg("partial") = false)
py::arg("path"), .def("loadDict", &QuantizedSanaModel::loadDict, py::arg("dict"), py::arg("partial") = false)
py::arg("partial") = false
)
.def("loadDict", &QuantizedSanaModel::loadDict,
py::arg("dict"),
py::arg("partial") = false
)
.def("forward", &QuantizedSanaModel::forward) .def("forward", &QuantizedSanaModel::forward)
.def("forward_layer", &QuantizedSanaModel::forward_layer) .def("forward_layer", &QuantizedSanaModel::forward_layer)
.def("startDebug", &QuantizedSanaModel::startDebug) .def("startDebug", &QuantizedSanaModel::startDebug)
.def("stopDebug", &QuantizedSanaModel::stopDebug) .def("stopDebug", &QuantizedSanaModel::stopDebug)
.def("getDebugResults", &QuantizedSanaModel::getDebugResults) .def("getDebugResults", &QuantizedSanaModel::getDebugResults);
;
py::class_<QuantizedGEMM>(m, "QuantizedGEMM") py::class_<QuantizedGEMM>(m, "QuantizedGEMM")
.def(py::init<>()) .def(py::init<>())
.def("init", &QuantizedGEMM::init) .def("init", &QuantizedGEMM::init)
...@@ -88,8 +83,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -88,8 +83,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("quantize", &QuantizedGEMM::quantize) .def("quantize", &QuantizedGEMM::quantize)
.def("startDebug", &QuantizedGEMM::startDebug) .def("startDebug", &QuantizedGEMM::startDebug)
.def("stopDebug", &QuantizedGEMM::stopDebug) .def("stopDebug", &QuantizedGEMM::stopDebug)
.def("getDebugResults", &QuantizedGEMM::getDebugResults) .def("getDebugResults", &QuantizedGEMM::getDebugResults);
; py::class_<Tensor>(m, "Tensor");
py::class_<QuantizedGEMM88>(m, "QuantizedGEMM88") py::class_<QuantizedGEMM88>(m, "QuantizedGEMM88")
.def(py::init<>()) .def(py::init<>())
.def("init", &QuantizedGEMM88::init) .def("init", &QuantizedGEMM88::init)
...@@ -98,8 +93,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -98,8 +93,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("forward", &QuantizedGEMM88::forward) .def("forward", &QuantizedGEMM88::forward)
.def("startDebug", &QuantizedGEMM88::startDebug) .def("startDebug", &QuantizedGEMM88::startDebug)
.def("stopDebug", &QuantizedGEMM88::stopDebug) .def("stopDebug", &QuantizedGEMM88::stopDebug)
.def("getDebugResults", &QuantizedGEMM88::getDebugResults) .def("getDebugResults", &QuantizedGEMM88::getDebugResults);
;
m.def_submodule("ops") m.def_submodule("ops")
.def("gemm_w4a4", nunchaku::ops::gemm_w4a4) .def("gemm_w4a4", nunchaku::ops::gemm_w4a4)
...@@ -108,16 +102,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -108,16 +102,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("gemv_awq", nunchaku::ops::gemv_awq) .def("gemv_awq", nunchaku::ops::gemv_awq)
.def("test_rmsnorm_rope", nunchaku::ops::test_rmsnorm_rope) .def("test_rmsnorm_rope", nunchaku::ops::test_rmsnorm_rope)
.def("test_pack_qkv", nunchaku::ops::test_pack_qkv) .def("test_pack_qkv", nunchaku::ops::test_pack_qkv);
;
m.def_submodule("utils") m.def_submodule("utils")
.def("set_log_level", [](const std::string &level) { .def("set_log_level", [](const std::string &level) { spdlog::set_level(spdlog::level::from_str(level)); })
spdlog::set_level(spdlog::level::from_str(level));
})
.def("set_cuda_stack_limit", nunchaku::utils::set_cuda_stack_limit) .def("set_cuda_stack_limit", nunchaku::utils::set_cuda_stack_limit)
.def("disable_memory_auto_release", nunchaku::utils::disable_memory_auto_release) .def("disable_memory_auto_release", nunchaku::utils::disable_memory_auto_release)
.def("trim_memory", nunchaku::utils::trim_memory) .def("trim_memory", nunchaku::utils::trim_memory)
.def("set_faster_i2f_mode", nunchaku::utils::set_faster_i2f_mode) .def("set_faster_i2f_mode", nunchaku::utils::set_faster_i2f_mode);
;
} }
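A hedged usage sketch of the extension bound above. The import path (nunchaku._C) and the checkpoint filename are assumptions; in practice the Python-side NunchakuFluxTransformer2dModel wraps these calls:

```python
# Hypothetical direct use of the compiled TORCH_EXTENSION_NAME module.
from nunchaku import _C  # assumed name of the built extension

_C.utils.set_log_level("debug")

model = _C.QuantizedFluxModel()
model.init(use_fp4=False, offload=False, bf16=True, deviceId=0)
model.load("svdq-int4-flux.1-dev.safetensors")  # hypothetical checkpoint path

# out = model.forward(hidden_states, encoder_hidden_states, temb,
#                     rotary_emb_img, rotary_emb_context, rotary_emb_single,
#                     skip_first_layer=False)
```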
@@ -11,13 +11,13 @@ public:
    void init(pybind11::dict config, std::vector<int> pag_layers, bool use_fp4, bool bf16, int8_t deviceId) {
        spdlog::info("Initializing QuantizedSanaModel on device {}", deviceId);

        SanaConfig cfg{
            .num_layers = config["num_layers"].cast<int>(),
            .num_attention_heads = config["num_attention_heads"].cast<int>(),
            .attention_head_dim = config["attention_head_dim"].cast<int>(),
            .num_cross_attention_heads = config["num_cross_attention_heads"].cast<int>(),
            .expand_ratio = config["mlp_ratio"].cast<double>(),
            .pag_layers = pag_layers,
            .use_fp4 = use_fp4,
        };

        ModuleWrapper::init(deviceId);
@@ -25,39 +25,37 @@ public:
        net = std::make_unique<SanaModel>(cfg, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
    }

    torch::Tensor forward(torch::Tensor hidden_states,
                          torch::Tensor encoder_hidden_states,
                          torch::Tensor timestep,
                          torch::Tensor cu_seqlens_img,
                          torch::Tensor cu_seqlens_txt,
                          int H,
                          int W,
                          bool pag,
                          bool cfg,
                          bool skip_first_layer = false) {
        checkModel();
        CUDADeviceContext ctx(deviceId);

        spdlog::debug("QuantizedSanaModel forward");

        hidden_states = hidden_states.contiguous();
        encoder_hidden_states = encoder_hidden_states.contiguous();
        timestep = timestep.contiguous();
        cu_seqlens_img = cu_seqlens_img.contiguous();
        cu_seqlens_txt = cu_seqlens_txt.contiguous();

        Tensor result = net->forward(from_torch(hidden_states),
                                     from_torch(encoder_hidden_states),
                                     from_torch(timestep),
                                     from_torch(cu_seqlens_img),
                                     from_torch(cu_seqlens_txt),
                                     H,
                                     W,
                                     pag,
                                     cfg,
                                     skip_first_layer);

        torch::Tensor output = to_torch(result);
        // Tensor::synchronizeDevice();
@@ -65,42 +63,40 @@ public:
        return output;
    }

    torch::Tensor forward_layer(int64_t idx,
                                torch::Tensor hidden_states,
                                torch::Tensor encoder_hidden_states,
                                torch::Tensor timestep,
                                torch::Tensor cu_seqlens_img,
                                torch::Tensor cu_seqlens_txt,
                                int H,
                                int W,
                                bool pag,
                                bool cfg) {
        checkModel();
        CUDADeviceContext ctx(deviceId);

        spdlog::debug("QuantizedSanaModel forward_layer {}", idx);

        hidden_states = hidden_states.contiguous();
        encoder_hidden_states = encoder_hidden_states.contiguous();
        timestep = timestep.contiguous();
        cu_seqlens_img = cu_seqlens_img.contiguous();
        cu_seqlens_txt = cu_seqlens_txt.contiguous();

        Tensor result = net->transformer_blocks.at(idx)->forward(from_torch(hidden_states),
                                                                 from_torch(encoder_hidden_states),
                                                                 from_torch(timestep),
                                                                 from_torch(cu_seqlens_img),
                                                                 from_torch(cu_seqlens_txt),
                                                                 H,
                                                                 W,
                                                                 pag,
                                                                 cfg);

        torch::Tensor output = to_torch(result);
        // Tensor::synchronizeDevice();

        return output;
    }
};
@@ -6,34 +6,34 @@
namespace nunchaku::utils {

void set_cuda_stack_limit(int64_t newval) {
    size_t val = 0;
    checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, (size_t)newval));
    checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
    spdlog::debug("Stack={}", val);
}

void disable_memory_auto_release() {
    int device;
    checkCUDA(cudaGetDevice(&device));
    cudaMemPool_t mempool;
    checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
    uint64_t threshold = UINT64_MAX;
    checkCUDA(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &threshold));
}

void trim_memory() {
    int device;
    checkCUDA(cudaGetDevice(&device));
    cudaMemPool_t mempool;
    checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
    size_t bytesToKeep = 0;
    checkCUDA(cudaMemPoolTrimTo(mempool, bytesToKeep));
}

void set_faster_i2f_mode(std::string mode) {
    spdlog::info("Set fasteri2f mode to {}", mode);
    kernels::set_faster_i2f_mode(mode);
}

}; // namespace nunchaku::utils
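The helpers above raise the CUDA default mempool's release threshold and trim it back on demand. A hedged sketch of calling them from Python through the utils submodule (the nunchaku._C import path is an assumption):

```python
# Hypothetical use of the memory helpers bound in the "utils" submodule.
from nunchaku import _C  # assumed extension name

_C.utils.disable_memory_auto_release()  # keep freed blocks cached during the denoising loop
# ... run inference ...
_C.utils.trim_memory()                  # hand cached blocks back to the driver afterwards
```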
from .diffusers_converter import to_diffusers
from .nunchaku_converter import convert_to_nunchaku_flux_lowrank_dict, to_nunchaku
from .utils import is_nunchaku_format
__all__ = ["to_diffusers", "to_nunchaku", "convert_to_nunchaku_flux_lowrank_dict", "is_nunchaku_format"]
@@ -7,10 +7,10 @@ import torch
from safetensors.torch import save_file
from tqdm import tqdm

from ...utils import filter_state_dict, load_state_dict_in_safetensors
from .diffusers_converter import to_diffusers
from .packer import NunchakuWeightPacker
from .utils import is_nunchaku_format, pad

logger = logging.getLogger(__name__)
...
# Copy the packer from https://github.com/mit-han-lab/deepcompressor/
import torch

from ...utils import ceil_divide
from .utils import pad


class MmaWeightPackerBase:
...
from .text_encoders.t5_encoder import NunchakuT5EncoderModel
from .transformers import NunchakuFluxTransformer2dModel, NunchakuSanaTransformer2DModel
__all__ = ["NunchakuFluxTransformer2dModel", "NunchakuSanaTransformer2DModel", "NunchakuT5EncoderModel"]
# Adapted from https://github.com/ToTheBeginning/PuLID
import math
import torch
from torch import nn
# FFN
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)
def reshape_tensor(x, heads):
bs, length, width = x.shape
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
x = x.reshape(bs, heads, length, -1)
return x
class PerceiverAttentionCA(nn.Module):
def __init__(self, *, dim=3072, dim_head=128, heads=16, kv_dim=2048):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latents (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, seq_len, _ = latents.shape
q = self.to_q(latents)
k, v = self.to_kv(x).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1)
return self.to_out(out)
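A quick shape sketch for PerceiverAttentionCA under its default arguments (dim=3072, dim_head=128, heads=16, kv_dim=2048): image features serve as keys/values and the latents as queries. The batch size and token counts below are illustrative only.

```python
# Shape check for PerceiverAttentionCA as defined above.
import torch

ca = PerceiverAttentionCA()            # defaults: dim=3072, dim_head=128, heads=16, kv_dim=2048
x = torch.randn(2, 577, 2048)          # (b, n1, kv_dim) image features
latents = torch.randn(2, 4096, 3072)   # (b, n2, dim) latent features
out = ca(x, latents)
assert out.shape == (2, 4096, 3072)    # output stays in the latent width
```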
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8, kv_dim=None):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latents (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, seq_len, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1)
return self.to_out(out)
class IDFormer(nn.Module):
"""
- perceiver resampler like arch (compared with previous MLP-like arch)
- we concat id embedding (generated by arcface) and query tokens as latents
- latents will attend each other and interact with vit features through cross-attention
- vit features are multi-scaled and inserted into IDFormer in order, currently, each scale corresponds to two
IDFormer layers
"""
def __init__(
self,
dim=1024,
depth=10,
dim_head=64,
heads=16,
num_id_token=5,
num_queries=32,
output_dim=2048,
ff_mult=4,
):
super().__init__()
self.num_id_token = num_id_token
self.dim = dim
self.num_queries = num_queries
assert depth % 5 == 0
self.depth = depth // 5
scale = dim**-0.5
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) * scale)
self.proj_out = nn.Parameter(scale * torch.randn(dim, output_dim))
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]
)
)
for i in range(5):
setattr(
self,
f"mapping_{i}",
nn.Sequential(
nn.Linear(1024, 1024),
nn.LayerNorm(1024),
nn.LeakyReLU(),
nn.Linear(1024, 1024),
nn.LayerNorm(1024),
nn.LeakyReLU(),
nn.Linear(1024, dim),
),
)
self.id_embedding_mapping = nn.Sequential(
nn.Linear(1280, 1024),
nn.LayerNorm(1024),
nn.LeakyReLU(),
nn.Linear(1024, 1024),
nn.LayerNorm(1024),
nn.LeakyReLU(),
nn.Linear(1024, dim * num_id_token),
)
def forward(self, x, y):
latents = self.latents.repeat(x.size(0), 1, 1)
num_duotu = x.shape[1] if x.ndim == 3 else 1
x = self.id_embedding_mapping(x)
x = x.reshape(-1, self.num_id_token * num_duotu, self.dim)
latents = torch.cat((latents, x), dim=1)
for i in range(5):
vit_feature = getattr(self, f"mapping_{i}")(y[i])
ctx_feature = torch.cat((x, vit_feature), dim=1)
for attn, ff in self.layers[i * self.depth : (i + 1) * self.depth]:
latents = attn(ctx_feature, latents) + latents
latents = ff(latents) + latents
latents = latents[:, : self.num_queries]
latents = latents @ self.proj_out
return latents
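A usage sketch for IDFormer with its default configuration: one ArcFace-style id embedding plus five scales of ViT features are resampled into 32 query tokens of width output_dim=2048. Batch size and per-scale token counts are illustrative.

```python
# Usage sketch for IDFormer as defined above.
import torch

idformer = IDFormer()                                          # defaults: dim=1024, depth=10, num_queries=32
id_embedding = torch.randn(2, 1280)                            # ArcFace-style id embedding
vit_features = [torch.randn(2, 577, 1024) for _ in range(5)]   # one feature tensor per scale

id_tokens = idformer(id_embedding, vit_features)
assert id_tokens.shape == (2, 32, 2048)                        # num_queries x output_dim
```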
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .factory import create_model_and_transforms
__all__ = ["create_model_and_transforms", "OPENAI_DATASET_MEAN", "OPENAI_DATASET_STD"]
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
import json
import logging
import os
import re
from copy import deepcopy
from pathlib import Path
from typing import Optional, Tuple, Union
import torch
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import CLIP, CustomCLIP, convert_to_custom_text_state_dict, get_cast_dtype
from .pretrained import download_pretrained, get_pretrained_cfg, list_pretrained_tags_by_model
from .transform import image_transform
from .utils import resize_clip_pos_embed, resize_eva_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed
_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
_MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs
def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
def _rescan_model_configs():
global _MODEL_CONFIGS
config_ext = (".json",)
config_files = []
for config_path in _MODEL_CONFIG_PATHS:
if config_path.is_file() and config_path.suffix in config_ext:
config_files.append(config_path)
elif config_path.is_dir():
for ext in config_ext:
config_files.extend(config_path.glob(f"*{ext}"))
for cf in config_files:
with open(cf, "r", encoding="utf8") as f:
model_cfg = json.load(f)
if all(a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")):
_MODEL_CONFIGS[cf.stem] = model_cfg
_MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))
_rescan_model_configs() # initial populate of model config registry
def list_models():
"""enumerate available model architectures based on config files"""
return list(_MODEL_CONFIGS.keys())
def get_model_config(model_name):
if model_name in _MODEL_CONFIGS:
return deepcopy(_MODEL_CONFIGS[model_name])
else:
return None
# loading openai CLIP weights when is_openai=True for training
def load_state_dict(
checkpoint_path: str,
map_location: str = "cpu",
model_key: str = "model|module|state_dict",
is_openai: bool = False,
skip_list: list = [],
):
if is_openai:
model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
state_dict = model.state_dict()
for key in ["input_resolution", "context_length", "vocab_size"]:
state_dict.pop(key, None)
else:
checkpoint = torch.load(checkpoint_path, map_location=map_location)
for mk in model_key.split("|"):
if isinstance(checkpoint, dict) and mk in checkpoint:
state_dict = checkpoint[mk]
break
else:
state_dict = checkpoint
if next(iter(state_dict.items()))[0].startswith("module"):
state_dict = {k[7:]: v for k, v in state_dict.items()}
for k in skip_list:
if k in list(state_dict.keys()):
logging.info(f"Removing key {k} from pretrained checkpoint")
del state_dict[k]
if os.getenv("RoPE") == "1":
for k in list(state_dict.keys()):
if "freqs_cos" in k or "freqs_sin" in k:
del state_dict[k]
return state_dict
def load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=True):
state_dict = load_state_dict(checkpoint_path, model_key=model_key, is_openai=False)
# detect old format and make compatible with new format
if "positional_embedding" in state_dict and not hasattr(model, "positional_embedding"):
state_dict = convert_to_custom_text_state_dict(state_dict)
if "text.logit_scale" in state_dict and hasattr(model, "logit_scale"):
state_dict["logit_scale"] = state_dict["text.logit_scale"]
del state_dict["text.logit_scale"]
# resize_clip_pos_embed for CLIP and open CLIP
if "visual.positional_embedding" in state_dict:
resize_clip_pos_embed(state_dict, model)
# specified to eva_vit_model
elif "visual.pos_embed" in state_dict:
resize_evaclip_pos_embed(state_dict, model)
# resize_clip_pos_embed(state_dict, model)
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
# logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}")
return incompatible_keys
def load_clip_visual_state_dict(
checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: list = []
):
state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
for k in list(state_dict.keys()):
if not k.startswith("visual."):
del state_dict[k]
for k in list(state_dict.keys()):
if k.startswith("visual."):
new_k = k[7:]
state_dict[new_k] = state_dict[k]
del state_dict[k]
return state_dict
def load_clip_text_state_dict(
checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: list = []
):
state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
for k in list(state_dict.keys()):
if k.startswith("visual."):
del state_dict[k]
return state_dict
def get_pretrained_tag(pretrained_model):
pretrained_model = pretrained_model.lower()
if "laion" in pretrained_model or "open_clip" in pretrained_model:
return "open_clip"
elif "openai" in pretrained_model:
return "clip"
elif "eva" in pretrained_model and "clip" in pretrained_model:
return "eva_clip"
else:
return "other"
def load_pretrained_checkpoint(
model,
visual_checkpoint_path,
text_checkpoint_path,
strict=True,
visual_model=None,
text_model=None,
model_key="model|module|state_dict",
skip_list=[],
):
visual_tag = get_pretrained_tag(visual_model)
text_tag = get_pretrained_tag(text_model)
logging.info(f"num of model state_dict keys: {len(model.state_dict().keys())}")
visual_incompatible_keys, text_incompatible_keys = None, None
if visual_checkpoint_path:
if visual_tag == "eva_clip" or visual_tag == "open_clip":
visual_state_dict = load_clip_visual_state_dict(
visual_checkpoint_path, is_openai=False, skip_list=skip_list
)
elif visual_tag == "clip":
visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list)
else:
visual_state_dict = load_state_dict(
visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list
)
# resize_clip_pos_embed for CLIP and open CLIP
if "positional_embedding" in visual_state_dict:
resize_visual_pos_embed(visual_state_dict, model)
# specified to EVA model
elif "pos_embed" in visual_state_dict:
resize_eva_pos_embed(visual_state_dict, model)
visual_incompatible_keys = model.visual.load_state_dict(visual_state_dict, strict=strict)
logging.info(f"num of loaded visual_state_dict keys: {len(visual_state_dict.keys())}")
logging.info(f"visual_incompatible_keys.missing_keys: {visual_incompatible_keys.missing_keys}")
if text_checkpoint_path:
if text_tag == "eva_clip" or text_tag == "open_clip":
text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=False, skip_list=skip_list)
elif text_tag == "clip":
text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=True, skip_list=skip_list)
else:
text_state_dict = load_state_dict(
visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list
)
text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict)
logging.info(f"num of loaded text_state_dict keys: {len(text_state_dict.keys())}")
logging.info(f"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}")
return visual_incompatible_keys, text_incompatible_keys
def create_model(
model_name: str,
pretrained: Optional[str] = None,
precision: str = "fp32",
device: Union[str, torch.device] = "cpu",
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_clip: bool = False,
force_patch_dropout: Optional[float] = None,
pretrained_image: str = "",
pretrained_text: str = "",
pretrained_hf: bool = True,
pretrained_visual_model: str = None,
pretrained_text_model: str = None,
cache_dir: Optional[str] = None,
skip_list: list = [],
):
model_name = model_name.replace("/", "-") # for callers using old naming with / in ViT names
if isinstance(device, str):
device = torch.device(device)
if pretrained and pretrained.lower() == "openai":
pass
else:
model_cfg = get_model_config(model_name)
if model_cfg is not None:
logging.info(f"Loaded {model_name} model config.")
else:
logging.error(f"Model config for {model_name} not found; available models {list_models()}.")
raise RuntimeError(f"Model config for {model_name} not found.")
if "rope" in model_cfg.get("vision_cfg", {}):
if model_cfg["vision_cfg"]["rope"]:
os.environ["RoPE"] = "1"
else:
os.environ["RoPE"] = "0"
if force_quick_gelu:
# override for use of QuickGELU on non-OpenAI transformer models
model_cfg["quick_gelu"] = True
if force_patch_dropout is not None:
# override the default patch dropout value
model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
cast_dtype = get_cast_dtype(precision)
custom_clip = (
model_cfg.pop("custom_text", False) or force_custom_clip or ("hf_model_name" in model_cfg["text_cfg"])
)
if custom_clip:
if "hf_model_name" in model_cfg.get("text_cfg", {}):
model_cfg["text_cfg"]["hf_model_pretrained"] = pretrained_hf
model = CustomCLIP(**model_cfg, cast_dtype=cast_dtype)
else:
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
pretrained_cfg = {}
if pretrained:
checkpoint_path = ""
pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
if pretrained_cfg:
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
elif os.path.exists(pretrained):
checkpoint_path = pretrained
if checkpoint_path:
logging.info(f"Loading pretrained {model_name} weights ({pretrained}).")
load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=False)
else:
error_str = (
    f"Pretrained weights ({pretrained}) not found for model {model_name}. "
    f"Available pretrained tags ({list_pretrained_tags_by_model(model_name)})."
)
logging.warning(error_str)
raise RuntimeError(error_str)
else:
visual_checkpoint_path = ""
text_checkpoint_path = ""
if pretrained_image:
pretrained_visual_model = pretrained_visual_model.replace(
"/", "-"
) # for callers using old naming with / in ViT names
pretrained_image_cfg = get_pretrained_cfg(pretrained_visual_model, pretrained_image)
if "timm_model_name" in model_cfg.get("vision_cfg", {}):
# pretrained weight loading for timm models set via vision_cfg
model_cfg["vision_cfg"]["timm_model_pretrained"] = True
elif pretrained_image_cfg:
visual_checkpoint_path = download_pretrained(pretrained_image_cfg, cache_dir=cache_dir)
elif os.path.exists(pretrained_image):
visual_checkpoint_path = pretrained_image
else:
logging.warning(
f"Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual."
)
raise RuntimeError(
f"Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual."
)
if pretrained_text:
pretrained_text_model = pretrained_text_model.replace(
"/", "-"
) # for callers using old naming with / in ViT names
pretrained_text_cfg = get_pretrained_cfg(pretrained_text_model, pretrained_text)
if pretrained_text_cfg:
text_checkpoint_path = download_pretrained(pretrained_text_cfg, cache_dir=cache_dir)
elif os.path.exists(pretrained_text):
text_checkpoint_path = pretrained_text
else:
logging.warning(
f"Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text."
)
raise RuntimeError(
f"Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text."
)
if visual_checkpoint_path:
logging.info(f"Loading pretrained {model_name}.visual weights ({visual_checkpoint_path}).")
if text_checkpoint_path:
logging.info(f"Loading pretrained {model_name}.text weights ({text_checkpoint_path}).")
if visual_checkpoint_path or text_checkpoint_path:
load_pretrained_checkpoint(
model,
visual_checkpoint_path,
text_checkpoint_path,
strict=False,
visual_model=pretrained_visual_model,
text_model=pretrained_text_model,
model_key="model|module|state_dict",
skip_list=skip_list,
)
if "fp16" in precision or "bf16" in precision:
logging.info(f"convert precision to {precision}")
model = model.to(torch.bfloat16) if "bf16" in precision else model.to(torch.float16)
model.to(device=device)
        # set image mean / std metadata from pretrained_cfg if available, otherwise use defaults
model.visual.image_mean = pretrained_cfg.get("mean", None) or OPENAI_DATASET_MEAN
model.visual.image_std = pretrained_cfg.get("std", None) or OPENAI_DATASET_STD
if jit:
model = torch.jit.script(model)
return model
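
# Hedged usage sketch (not part of the original factory): shows a typical create_model call.
# The config name "ViT-B-16" and the bf16/CPU combination are illustrative assumptions only;
# pretrained=None skips any checkpoint download.
def _example_create_bf16_model():
    """Sketch only: per the precision handling above, "bf16" casts the model to torch.bfloat16."""
    return create_model("ViT-B-16", pretrained=None, precision="bf16", device="cpu")
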
def create_model_and_transforms(
model_name: str,
pretrained: Optional[str] = None,
precision: str = "fp32",
device: Union[str, torch.device] = "cpu",
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_clip: bool = False,
force_patch_dropout: Optional[float] = None,
pretrained_image: str = "",
pretrained_text: str = "",
pretrained_hf: bool = True,
    pretrained_visual_model: Optional[str] = None,
    pretrained_text_model: Optional[str] = None,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
cache_dir: Optional[str] = None,
skip_list: list = [],
):
model = create_model(
model_name,
pretrained,
precision=precision,
device=device,
jit=jit,
force_quick_gelu=force_quick_gelu,
force_custom_clip=force_custom_clip,
force_patch_dropout=force_patch_dropout,
pretrained_image=pretrained_image,
pretrained_text=pretrained_text,
pretrained_hf=pretrained_hf,
pretrained_visual_model=pretrained_visual_model,
pretrained_text_model=pretrained_text_model,
cache_dir=cache_dir,
skip_list=skip_list,
)
image_mean = image_mean or getattr(model.visual, "image_mean", None)
image_std = image_std or getattr(model.visual, "image_std", None)
preprocess_train = image_transform(model.visual.image_size, is_train=True, mean=image_mean, std=image_std)
preprocess_val = image_transform(model.visual.image_size, is_train=False, mean=image_mean, std=image_std)
return model, preprocess_train, preprocess_val
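
# Hedged usage sketch: create_model_and_transforms returns the model together with the
# train/val preprocessing pipelines. The config name below is an assumption for illustration,
# not a value taken from this repository's configs.
def _example_model_with_transforms():
    """Sketch only: unpack the (model, train transform, val transform) triple."""
    model, preprocess_train, preprocess_val = create_model_and_transforms(
        "ViT-B-16",        # hypothetical model config name
        pretrained=None,   # no checkpoint download in this sketch
        precision="fp32",
        device="cpu",
    )
    return model, preprocess_train, preprocess_val
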
# HF architecture dict:
arch_dict = {
# https://huggingface.co/docs/transformers/model_doc/roberta#roberta
"roberta": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layer",
"token_embeddings_attr": "embeddings",
},
"pooler": "mean_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
"xlm-roberta": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layer",
"token_embeddings_attr": "embeddings",
},
"pooler": "mean_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/mt5#mt5
"mt5": {
"config_names": {
# unlimited seqlen
# https://github.com/google-research/text-to-text-transfer-transformer/issues/273
# https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
"context_length": "",
"vocab_size": "vocab_size",
"width": "d_model",
"heads": "num_heads",
"layers": "num_layers",
"layer_attr": "block",
"token_embeddings_attr": "embed_tokens",
},
"pooler": "mean_pooler",
},
"bert": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layer",
"token_embeddings_attr": "embeddings",
},
"pooler": "mean_pooler",
},
}
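
# Hedged sketch of how the mapping above is consumed: given a HuggingFace config object,
# look up the attribute name for a logical field (e.g. "width") and read it off the config.
# The helper name is illustrative and not part of this module's public API.
def _read_cfg_field(hf_config, field: str):
    """Return the HF config value for a logical field, or None when unmapped (e.g. mt5 context_length)."""
    attr = arch_dict[hf_config.model_type]["config_names"][field]
    return getattr(hf_config, attr) if attr else None
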
"""huggingface model adapter
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
"""
import re
import torch
import torch.nn as nn
from torch import TensorType
try:
import transformers
from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM, AutoTokenizer, PretrainedConfig
except ImportError:
transformers = None
class PretrainedConfig:
pass
from .hf_configs import arch_dict
# utils
def _camel2snake(s):
return re.sub(r"(?<!^)(?=[A-Z])", "_", s).lower()
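# Note: _camel2snake("MeanPooler") -> "mean_pooler"; this is how pooler registry keys are derived.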
# TODO: ?last - for gpt-like models
_POOLERS = {}
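
# Hedged sketch (an assumption, not necessarily this repository's exact code): poolers are
# looked up from _POOLERS by snake_case name in HFTextEncoder below, so a registration
# decorator along these lines would populate the registry, and a mean pooler matching the
# "mean_pooler" key used in arch_dict could look like this.
def register_pooler(cls):
    """Register a pooler class under its snake_case name."""
    _POOLERS[_camel2snake(cls.__name__)] = cls
    return cls


@register_pooler
class MeanPooler(nn.Module):
    """Mean-pool the token embeddings over the (non-padding) attention mask."""

    def forward(self, x, attention_mask):
        masked = x.last_hidden_state * attention_mask.unsqueeze(-1)
        return masked.sum(dim=1) / attention_mask.sum(dim=-1, keepdim=True)
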
class HFTextEncoder(nn.Module):
"""HuggingFace model adapter"""
def __init__(
self,
model_name_or_path: str,
output_dim: int,
tokenizer_name: str = None,
config: PretrainedConfig = None,
pooler_type: str = None,
proj: str = None,
pretrained: bool = True,
masked_language_modeling: bool = False,
):
super().__init__()
self.output_dim = output_dim
# TODO: find better way to get this information
uses_transformer_pooler = pooler_type == "cls_pooler"
if transformers is None:
raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
if config is None:
self.config = AutoConfig.from_pretrained(model_name_or_path)
if masked_language_modeling:
create_func, model_args = (
(AutoModelForMaskedLM.from_pretrained, model_name_or_path)
if pretrained
else (AutoModelForMaskedLM.from_config, self.config)
)
else:
create_func, model_args = (
(AutoModel.from_pretrained, model_name_or_path)
if pretrained
else (AutoModel.from_config, self.config)
)
            # TODO: do all model configs have this attribute? PretrainedConfig does, so this should be safe.
if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
self.transformer = create_func(model_args)
self.transformer = self.transformer.encoder
else:
self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
else:
self.config = config
if masked_language_modeling:
self.transformer = AutoModelForMaskedLM.from_config(config)
else:
self.transformer = AutoModel.from_config(config)
if pooler_type is None: # get default arch pooler
self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]()
else:
self.pooler = _POOLERS[pooler_type]()
d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
if (d_model == output_dim) and (proj is None): # do we always need a proj?
self.proj = nn.Identity()
elif proj == "linear":
self.proj = nn.Linear(d_model, output_dim, bias=False)
elif proj == "mlp":
hidden_size = (d_model + output_dim) // 2
self.proj = nn.Sequential(
nn.Linear(d_model, hidden_size, bias=False),
nn.GELU(),
nn.Linear(hidden_size, output_dim, bias=False),
)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name if tokenizer_name else model_name_or_path)
def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None):
if masked_indices is None:
masked_indices = torch.bernoulli(probability_matrix).bool()
masked_indices[input_ids == self.tokenizer.pad_token_id] = False
masked_indices[input_ids == self.tokenizer.cls_token_id] = False
if targets is not None:
targets[~masked_indices] = -100 # We only compute loss on masked tokens
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
input_ids[indices_replaced] = self.tokenizer.mask_token_id
# 10% of the time, we replace masked input tokens with random word
indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
input_ids[indices_random] = random_words[indices_random]
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
if targets is not None:
return input_ids, targets
else:
return input_ids
def forward(self, x: TensorType) -> TensorType:
attn_mask = (x != self.config.pad_token_id).long()
out = self.transformer(input_ids=x, attention_mask=attn_mask)
pooled_out = self.pooler(out, attn_mask)
return self.proj(pooled_out)
@torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        if enable:
            self.transformer.gradient_checkpointing_enable()
        else:
            self.transformer.gradient_checkpointing_disable()
def init_parameters(self):
pass
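
# Hedged usage sketch: constructing the adapter around a hypothetical HuggingFace checkpoint.
# The checkpoint id, output_dim, pooler, and projection are illustrative assumptions;
# pretrained=False builds the transformer from config without downloading weights.
def _example_build_text_tower():
    return HFTextEncoder(
        "roberta-base",             # hypothetical HF checkpoint id
        output_dim=512,
        tokenizer_name="roberta-base",
        pooler_type="mean_pooler",  # matches the arch_dict default for roberta
        proj="linear",
        pretrained=False,
    )
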