# 下载 jax 和 jaxlib

```bash
wget https://cancon.hpccube.com:65024/directlink/4/jax/DAS1.1/jaxlib-0.4.23+das1.1.git387bd43.abi1.dtk2404-cp39-cp39-manylinux_2_31_x86_64.whl
wget https://cancon.hpccube.com:65024/directlink/4/jax/DAS1.1/jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
```

# 创建 conda 环境

jaxlib 的 wheel 标记为 `cp39`（只能安装在 Python 3.9 上），因此下面创建环境时指定 `python=3.9`。环境名虽然叫 `pymc3`，但实际安装的是 PyMC 5.9.1，其导入名为 `pymc`。

```bash
conda create -n pymc3 python=3.9
conda activate pymc3
pip install jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
pip install jaxlib-0.4.23+das1.1.git387bd43.abi1.dtk2404-cp39-cp39-manylinux_2_31_x86_64.whl
pip install pymc==5.9.1
pip install numpyro==0.14.0
pip install seaborn==0.13.2
pip install scipy==1.12.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
```

安装完成后，环境中的包及版本如下：

```
arviz               0.17.1
cachetools          5.4.0
cloudpickle         3.0.0
cons                0.4.6
contourpy           1.2.1
cycler              0.12.1
etuples             0.3.9
fastprogress        1.0.3
filelock            3.15.4
fonttools           4.53.1
h5netcdf            1.3.0
h5py                3.11.0
importlib_metadata  8.2.0
importlib_resources 6.4.2
jax                 0.4.23+das1.1.git387bd43.abi1.dtk2404
jaxlib              0.4.23+das1.1.git387bd43.abi1.dtk2404
kiwisolver          1.4.5
logical-unification 0.4.6
matplotlib          3.9.2
miniKanren          1.0.3
ml-dtypes           0.4.0
multipledispatch    1.0.0
numpy               1.26.4
numpyro             0.14.0
opt-einsum          3.3.0
packaging           24.1
pandas              2.2.2
pillow              10.4.0
pip                 24.2
pymc                5.9.1
pyparsing           3.1.2
pytensor            2.17.4
python-dateutil     2.9.0.post0
pytz                2024.1
scipy               1.12.0
seaborn             0.13.2
setuptools          72.1.0
six                 1.16.0
toolz               0.12.1
tqdm                4.66.5
typing_extensions   4.12.2
tzdata              2024.1
wheel               0.43.0
xarray              2024.7.0
xarray-einstats     0.7.0
zipp                3.20.0
```

# 测试 JAX 能否在 GPU 上运行

```python
import jax
import pymc as pm  # 本环境安装的是 PyMC 5（包名 pymc），不存在 pymc3 模块

print(jax.default_backend())  # 若输出 cpu，说明 JAX 没有使用 GPU
print(jax.devices())
```

# 测试采样耗时、GPU 负载等情况

```python
import jax
import pymc as pm
import numpy as np
import pytensor as pt

pt.config.floatX = "float32"

np.random.seed(123)
n = 10000
X = np.random.randn(n)
Y = 3 * X + np.random.randn(n)

# 定义 PyMC 模型
with pm.Model() as model:
    alpha = pm.Normal('alpha', mu=0, sigma=10)
    beta = pm.Normal('beta', mu=0, sigma=10)
    sigma = pm.HalfNormal('sigma', sigma=1)
    mu = alpha + beta * X
    Y_obs = pm.Normal('Y_obs', mu=mu, sigma=sigma, observed=Y)
    # pm.sample 未传入 model=，因此必须在 with 块内调用
    trace = pm.sample(1000, nuts_sampler="numpyro", return_inferencedata=False)

print(pm.summary(trace))
```